diff --git a/NOTICE.md b/NOTICE.md new file mode 100644 index 00000000..c6d90cae --- /dev/null +++ b/NOTICE.md @@ -0,0 +1,3 @@ +This project is licensed under the MIT License. However, some files in this repository are under the Apache 2.0 License as noted below: + +- backbone/vit.py (Apache 2.0 License) diff --git a/backbone/EfficientNet.py b/backbone/EfficientNet.py index ddaa3dd5..3521fe31 100644 --- a/backbone/EfficientNet.py +++ b/backbone/EfficientNet.py @@ -155,7 +155,7 @@ def round_filters(filters, global_params): multiplier = global_params.width_coefficient if not multiplier: return filters - + divisor = global_params.depth_divisor min_depth = global_params.min_depth filters *= multiplier @@ -244,6 +244,7 @@ def get_same_padding_conv2d(image_size=None): else: return partial(Conv2dStaticSamePadding, image_size=image_size) + GlobalParams = collections.namedtuple('GlobalParams', [ 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format', 'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor', @@ -848,7 +849,7 @@ def forward(self, inputs, returnt='out'): x = self.classifier(x) if returnt == 'out': return x - elif returnt == 'all': + elif returnt == 'full': return (x, feats) raise NotImplementedError("Unknown return type") diff --git a/backbone/MNISTMLP.py b/backbone/MNISTMLP.py index d9e98f67..d39bfb3b 100644 --- a/backbone/MNISTMLP.py +++ b/backbone/MNISTMLP.py @@ -68,7 +68,7 @@ def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor: if returnt == 'out': return out - elif returnt == 'all': + elif returnt == 'full': return (out, feats) raise NotImplementedError("Unknown return type") diff --git a/backbone/ResNet18_PNN.py b/backbone/ResNet18_PNN.py index 0795299e..f24b4c4a 100644 --- a/backbone/ResNet18_PNN.py +++ b/backbone/ResNet18_PNN.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from torch.nn.functional import avg_pool2d, relu -from backbone.ResNet18 import BasicBlock, ResNet, conv3x3 +from backbone.ResNetBlock import BasicBlock, ResNet, conv3x3 from backbone.utils.modules import AlphaModule, ListModule diff --git a/backbone/ResNet18.py b/backbone/ResNetBlock.py similarity index 94% rename from backbone/ResNet18.py rename to backbone/ResNetBlock.py index 36537292..56d8f65c 100644 --- a/backbone/ResNet18.py +++ b/backbone/ResNetBlock.py @@ -110,6 +110,8 @@ def __init__(self, block: BasicBlock, num_blocks: List[int], self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2) self.classifier = nn.Linear(nf * 8 * block.expansion, num_classes) + self.feature_dim = nf * 8 * block.expansion + def to(self, device, **kwargs): self.device = device return super().to(device, **kwargs) @@ -185,7 +187,7 @@ def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor: out_4 if not self.return_prerelu else self.layer4[-1].prerelu ] - raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'all'] but got {}".format(returnt)) + raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'full'] but got {}".format(returnt)) def resnet18(nclasses: int, nf: int = 64) -> ResNet: @@ -200,3 +202,17 @@ def resnet18(nclasses: int, nf: int = 64) -> ResNet: ResNet network """ return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf) + + +def resnet34(nclasses: int, nf: int = 64) -> ResNet: + """ + Instantiates a ResNet34 network. + + Args: + nclasses: number of output classes + nf: number of filters + + Returns: + ResNet network + """ + return ResNet(BasicBlock, [3, 4, 6, 3], nclasses, nf) diff --git a/backbone/ResNet50.py b/backbone/ResNetBottleneck.py similarity index 98% rename from backbone/ResNet50.py rename to backbone/ResNetBottleneck.py index e3c9bb33..84ac54a2 100644 --- a/backbone/ResNet50.py +++ b/backbone/ResNetBottleneck.py @@ -146,6 +146,8 @@ def __init__( self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.classifier = nn.Linear(512 * block.expansion, num_classes) + self.feature_dim = 512 * block.expansion + for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( @@ -242,7 +244,7 @@ def forward(self, x: Tensor, returnt="out") -> Tensor: elif returnt == 'both': return (out, feature) - raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'all'] but got {}".format(returnt)) + raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'full'] but got {}".format(returnt)) def set_grad_filter(self, filter_s: str, enable: bool) -> None: negative_mode = filter_s[0] == '~' diff --git a/backbone/utils/layers.py b/backbone/utils/layers.py new file mode 100644 index 00000000..55547581 --- /dev/null +++ b/backbone/utils/layers.py @@ -0,0 +1,167 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LoRALayer(): + def __init__( + self, + lora_dropout: float, + ): + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + + +class LoRALinear(nn.Linear, LoRALayer): + + def __init__( + self, + in_features: int, + out_features: int, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, lora_dropout=lora_dropout) + + self.fan_in_fan_out = fan_in_fan_out + self.weight.requires_grad = False + self.reset_parameters() + + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + + def forward(self, x: torch.Tensor, AB: dict = None): + + def T(w): + return w.transpose(1, 2) if self.fan_in_fan_out else w + + result = F.linear(x, T(self.weight), bias=self.bias) + + if AB is not None: + A = None + if isinstance(AB, dict): + B = AB['B'] + A = AB.get('A') + else: + B = AB + if A is not None: + return result + (B @ (A @ x.transpose(1, 2).unsqueeze(1))).sum(1).transpose(1, 2) + return result + (B @ x.transpose(1, 2).unsqueeze(1)).sum(1).transpose(1, 2) + + return result + + +class ClipLinear(nn.Linear, LoRALayer): + + def __init__( + self, + in_features: int, + out_features: int, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, lora_dropout=lora_dropout) + + self.fan_in_fan_out = fan_in_fan_out + self.weight.requires_grad = False + self.reset_parameters() + + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + nn.Linear.reset_parameters(self) + + def forward(self, x: torch.Tensor, AB: dict = None): + + def T(w): + return w.transpose(1, 2) if self.fan_in_fan_out else w + + result = F.linear(x, T(self.weight), bias=self.bias) + + if AB is not None: + A = None + if isinstance(AB, dict): + B = AB['B'] + A = AB.get('A') + else: + B = AB + if A is not None: + res = (B @ (A @ torch.permute(x, (1, 2, 0)).unsqueeze(1))).sum(1) + return result + torch.permute(res, (2, 0, 1)) + res = (B @ torch.permute(x, (1, 2, 0)).unsqueeze(1)).sum(1) + return result + torch.permute(res, (2, 0, 1)) + + return result + + +class IncrementalClassifier(nn.Module): + + def __init__(self, embed_dim: int, nb_classes: int): + """ + Incremental classifier for continual learning. + + Args: + embed_dim: int, dimension of the input features. + nb_classes: int, number of classes to classify. + """ + + super().__init__() + + self.embed_dim = embed_dim + + heads = [nn.Linear(embed_dim, nb_classes, bias=True)] + + self.heads = nn.ModuleList(heads) + self.old_state_dict = None + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def update(self, nb_classes: int, freeze_old=True): + """ + Add a new head to the classifier. + + Args: + nb_classes, number of classes to add. + freeze_old: bool, whether to freeze the old heads. + """ + + _fc = nn.Linear(self.embed_dim, nb_classes, bias=True).to(self.get_device()) + + nn.init.trunc_normal_(_fc.weight, std=.02) + nn.init.constant_(_fc.bias, 0) + + if freeze_old: + for param in self.heads.parameters(): + param.requires_grad = False + + self.heads.append(_fc) + + def forward(self, x: torch.Tensor): + """ + Forward pass. + + Compute the logits for each head and concatenate them. + + Args: + x: torch.Tensor, input features. + """ + return torch.cat([h(x) for h in self.heads], dim=1) diff --git a/backbone/utils/vit_default_cfg.py b/backbone/utils/vit_default_cfg.py new file mode 100644 index 00000000..6770cc93 --- /dev/null +++ b/backbone/utils/vit_default_cfg.py @@ -0,0 +1,457 @@ +from typing import Optional, List +try: + from timm.models._pretrained import generate_default_cfgs +except BaseException: + from timm.models._registry import generate_default_cfgs +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ + OPENAI_CLIP_MEAN, OPENAI_CLIP_STD + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'vit_base_patch16_224_in21k_fn_in1k_old': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz'), + + # re-finetuned augreg 21k FT on in1k weights + 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(), + 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( + hf_hub_id='timm/'), + + # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k + 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + + # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k + 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + hf_hub_id='timm/'), + 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + + # How to train your ViT (augreg) weights trained on in1k only + 'vit_small_patch16_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch16_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch32_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch16_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + + 'vit_large_patch14_224.untrained': _cfg(url=''), + 'vit_huge_patch14_224.untrained': _cfg(url=''), + 'vit_giant_patch14_224.untrained': _cfg(url=''), + 'vit_gigantic_patch14_224.untrained': _cfg(url=''), + + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_large_patch32_224.orig_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + hf_hub_id='timm/', + num_classes=21843), + 'vit_huge_patch14_224.orig_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + + # How to train your ViT (augreg) weights, pretrained on in21k + 'vit_tiny_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_small_patch32_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_small_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch32_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch8_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_large_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + + # SAM trained models (https://arxiv.org/abs/2106.01548) + 'vit_base_patch32_224.sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True, + hf_hub_id='timm/'), + 'vit_base_patch16_224.sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True, + hf_hub_id='timm/'), + + # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) + 'vit_small_patch16_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_small_patch8_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch16_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch8_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', + hf_hub_id='timm/', + mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), + 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', + hf_hub_id='timm/', + mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), + + # Custom timm variants + 'vit_base_patch16_rpn_224.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth', + hf_hub_id='timm/'), + 'vit_medium_patch16_gap_240.in12k': _cfg( + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821), + 'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'), + 'vit_base_patch16_gap_224': _cfg(), + + # CLIP pretrained image tower and related fine-tuned weights + 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), + 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), + 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), + 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), + 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + + 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821), + 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), + + 'vit_base_patch32_clip_224.openai_ft_in12k': _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), + + 'vit_base_patch32_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.laion2b': _cfg( + # hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768), + 'vit_huge_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_giant_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_gigantic_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), + + 'vit_base_patch32_clip_224.openai': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.openai': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_large_patch14_clip_224.openai': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + + # experimental (may be removed) + 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), + 'vit_small_patch16_36x1_224': _cfg(url=''), + 'vit_small_patch16_18x2_224': _cfg(url=''), + 'vit_base_patch16_18x2_224': _cfg(url=''), + + # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain + # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip + 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_large_patch14_196.in22k_ft_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + + 'flexivit_small.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_small.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_small.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + + 'flexivit_base.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.1000ep_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'flexivit_base.300ep_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + + 'flexivit_large.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_large.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_large.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + + 'flexivit_base.patch16_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'flexivit_base.patch30_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), +}) diff --git a/backbone/vit.py b/backbone/vit.py new file mode 100644 index 00000000..9e53fde9 --- /dev/null +++ b/backbone/vit.py @@ -0,0 +1,758 @@ +""" Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +------------------------------------------------------------------------------- + +Cloned and trimmed version of timm.models.vision_transformer.py +Here for STABLE reference. + +Check out https://github.com/pprp/timm/blob/master/timm/models/vision_transformer.py for the original file. + +The following is the original docstring of the file. + +------------------------------------------------------------------------------- + +Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +`FlexiViT: One Model for All Patch Sizes` + - https://arxiv.org/abs/2212.08013 + +The official jax code is released and available at + * https://github.com/google-research/vision_transformer + * https://github.com/google-research/big_vision + +Acknowledgments: + * The paper authors for releasing code and weights, thanks! + * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch + * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT + * Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020, Ross Wightman +""" + +import logging +import math +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ + resample_abs_pos_embed +from timm.models._builder import build_model_with_cfg +from timm.models._manipulate import named_apply + +from backbone.utils.layers import LoRALinear, IncrementalClassifier + +from itertools import repeat +import collections.abc + +from backbone import MammothBackbone + +__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this + +_logger = logging.getLogger(__name__) + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse + + +to_2tuple = _ntuple(2) + + +class Attention(nn.Module): + """ + Attention layer as used in Vision Transformer. + Adapted to support LoRA-style parameters. + + Args: + dim: Number of input channels + num_heads: Number of attention heads + qkv_bias: If True, add a learnable bias to q, k, v + attn_drop: Dropout rate for attention weights + proj_drop: Dropout rate after the final projection + """ + + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = LoRALinear(dim, dim * 3, 0., bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = LoRALinear(dim, dim, 0.) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, AB: dict = None, **kwargs): + """ + Forward pass of the attention layer. + Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`). + + Args: + x: Input tensor + AB: Dictionary containing LoRA-style parameters for the layer + """ + + B, N, C = x.shape + + AB_qkv = None + + if AB is not None: + AB_qkv = AB.get("qkv") + + qkv = self.qkv(x, AB_qkv) + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # NOTE: flash attention is less debuggable than the original. Use the commented code below if in trouble. + x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p) + # attn = (q @ k.transpose(-2, -1)) * self.scale + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = (attn @ v) + + x = x.transpose(1, 2).reshape(B, N, C) + + AB_proj = None + + if AB is not None: + AB_proj = AB.get("proj") + + x = self.proj(x, AB_proj) + x = self.proj_drop(x) + + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Mlp(nn.Module): + """ + MLP as used in Vision Transformer, MLP-Mixer and related networks. + Adapted to support LoRA-style parameters. + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0., + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + assert use_conv is False + + self.fc1 = LoRALinear(in_features, hidden_features, bias=bias[0], lora_dropout=0.) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = LoRALinear(hidden_features, out_features, bias=bias[1], lora_dropout=0.) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x: torch.Tensor, AB: dict = None, **kwargs): + """ + Forward pass of the MLP layer. + Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`). + + Args: + x: Input tensor + AB: Dictionary containing LoRA-style parameters for the layer + """ + AB_fc1 = None + AB_fc2 = None + + if AB is not None: + AB_fc1 = AB.get("fc1") + AB_fc2 = AB.get("fc2") + + x = self.fc1(x, AB_fc1) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x, AB_fc2) + x = self.drop2(x) + + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + attn_layer=Attention, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, **kwargs): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), **kwargs))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x), **kwargs))) + return x + + +class VisionTransformer(MammothBackbone): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=None, + class_token=True, + no_embed_class=False, + pre_norm=False, + fc_norm=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + block_fn=Block, + attn_layer=Attention, + args=None + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + block_fn: (nn.Module): transformer block + attn_layer: (nn.Module): attention layer + args: (Namespace): optional command-line arguments + """ + super().__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.act_layer = act_layer or nn.GELU + + self.attn_layer = attn_layer + self.norm_layer = norm_layer + self.num_heads = num_heads + self.weight_init = weight_init + self.class_token = class_token + self.num_classes = num_classes + self.global_pool = global_pool + self.feature_dim = self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + self.mlp_ratio = mlp_ratio + self.args = args + self.init_values = init_values + self.qkv_bias = qkv_bias + self.attn_drop_rate = attn_drop_rate + self.depth = depth + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + self.pos_drop = nn.Dropout(p=drop_rate) + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + self.dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + init_values=init_values, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=self.dpr[i], + norm_layer=norm_layer, + act_layer=self.act_layer, + attn_layer=attn_layer + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head = IncrementalClassifier(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + + self.embed_dim = embed_dim + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed + return self.pos_drop(x) + + def forward_features(self, x: torch.Tensor, AB={}, return_all=False): + """ + Compute the forward pass of ViT (features only). + Can take in an additional argument `AB`, which is a dictionary containing LoRA-style parameters for each block. + + Args: + x: input tensor + AB: dictionary containing LoRA-style parameters for each block + return_all: whether to return all intermediate features + + Returns: + features for each patch + """ + int_features = [] + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.norm_pre(x) + # NOTE: grad checkpointing was removed from the original timm impl + for idx, blk in enumerate(self.blocks): + AB_blk = AB.get(idx) + if AB_blk is not None: + x = blk(x, AB_blk) + else: + x = blk(x) + if return_all: + int_features.append(x.clone()) + x = self.norm(x) + + if return_all: + int_features.append(x.clone()) + return int_features + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False): + """ + Compute the forward pass of ViT (head only). + Expects input of shape [batch_size, num_patches, embed_dim]. + + Args: + x: input tensor + pre_logits: whether to return the pre-logits (pooled features) or the final class scores + + Returns: + output tensor with shape [batch_size, num_classes] if `pre_logits` is False, else [batch_size, embed_dim] + """ + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor, AB: dict = {}, returnt='out'): + """ + Compute the forward pass of ViT. + Can take in an additional argument `AB`, which is a dictionary containing LoRA-style parameters for each block. + + `AB` can contain + - a single value for each block (e.g. `AB = {0: {"qkv": torch.Tensor(...)}, 1: {"qkv": torch.Tensor(...)}, ...}`) + - a dictionary for each block with a single key `B` (e.g. `AB = {0: {"qkv": {"B": torch.Tensor(...)}}}`) + - a dictionary for each block with both `A` and `B` keys of LoRA parameters (e.g. `AB = {0: {"qkv": {"A": torch.Tensor(...), "B": torch.Tensor(...)}}}`) + + Supported keys for each block are `qkv`, `proj`, `fc1`, `fc2`. + + NOTE: The values of `AB` are **summed** with the weights of the corresponding block. + + Args: + x: input tensor + AB: dictionary containing LoRA-style parameters for each block + returnt: return type (a string among `out`, `features`, `both`, or `full`) + + Returns: + output tensor + """ + assert returnt in ('out', 'features', 'both', 'full') + + x = self.forward_features(x, AB, return_all=returnt == 'full') + if returnt == 'full': + all_features = x + x = x[-1] + feats = self.forward_head(x, pre_logits=True) + + if returnt == 'features': + return feats + + out = self.head(feats) + + if returnt == 'both': + return out, feats + elif returnt == 'full': + return out, all_features + return out + + def get_params(self, discard_classifier=False) -> torch.Tensor: + """ + Returns all the parameters concatenated in a single tensor. + + Returns: + parameters tensor + """ + params = [] + for kk, pp in list(self.named_parameters()): + if not discard_classifier or not 'head' in kk: + params.append(pp.view(-1)) + return torch.cat(params) + + def get_grads(self, discard_classifier=False) -> torch.Tensor: + """ + Returns all the gradients concatenated in a single tensor. + + Returns: + gradients tensor + """ + grads = [] + for kk, pp in list(self.named_parameters()): + if not discard_classifier or not 'head' in kk: + grads.append(pp.grad.view(-1)) + return torch.cat(grads) + + +def init_weights_vit_timm(module: nn.Module, name: str = ''): + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): + """ ViT weight initialization, matching JAX (Flax) impl """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_moco(module: nn.Module, name: str = ''): + """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + if isinstance(module, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_init_weights_vit(mode='jax', head_bias: float = 0.): + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +def resize_pos_embed( + posemb, + posemb_new, + num_prefix_tokens=1, + gs_new=(), + interpolation='bicubic', + antialias=False, +): + """ Rescale the grid of position embeddings when loading from state_dict. + + *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed + + Adapted from: + https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + """ + ntok_new = posemb_new.shape[1] + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens + else: + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).') + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) + return posemb + + +def _convert_openai_clip(state_dict, model): + out_dict = {} + swaps = [ + ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'), + ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'), + ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'), + ] + for k, v in state_dict.items(): + if not k.startswith('visual.'): + continue + for sp in swaps: + k = k.replace(sp[0], sp[1]) + + if k == 'proj': + k = 'head.weight' + v = v.transpose(0, 1) + out_dict['head.bias'] = torch.zeros(v.shape[0]) + elif k == 'class_embedding': + k = 'cls_token' + v = v.unsqueeze(0).unsqueeze(1) + elif k == 'pos_embed': + v = v.unsqueeze(0) + if v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, + model.pos_embed, + 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + out_dict[k] = v + return out_dict + + +def checkpoint_filter_fn( + state_dict, + model, + adapt_layer_scale=False, + interpolation='bicubic', + antialias=True, +): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + + if 'visual.class_embedding' in state_dict: + return _convert_openai_clip(state_dict, model) + + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + O, I, H, W = model.patch_embed.proj.weight.shape + if len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) + v = resample_abs_pos_embed( + v, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif adapt_layer_scale and 'gamma_' in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue + out_dict[k] = v + return out_dict + + +def create_vision_transformer(variant, base_class=VisionTransformer, pretrained=False, filter_fn=checkpoint_filter_fn, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + if 'flexi' in variant: + # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed + # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. + _filter_fn = partial(filter_fn, interpolation='bilinear', antialias=False) + else: + _filter_fn = filter_fn + + if variant == 'vit_base_patch16_224_in21k_fn_in1k_old': + from timm.models.helpers import resolve_pretrained_cfg + + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) + pretrained_cfg.custom_load = True + + return build_model_with_cfg( + base_class, + variant, + pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_filter_fn=_filter_fn, + pretrained_strict=True, + **kwargs, + ) + else: + return build_model_with_cfg( + base_class, variant, pretrained, + pretrained_filter_fn=_filter_fn, + **kwargs, + ) + + +def vit_base_patch16_224_prompt_prototype(pretrained=False, pretrain_type='in21k-ft-in1k', **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + + By default, returns a model pre-trained on ImageNet-21k. + Supports: + - Pre-train on ImageNet-21k (pretrain_type='in21k') + - Pre-train on ImageNet-21k and finetuned on ImageNet-1k (pretrain_type='in21k_old') + - Pre-train with MoCoV3 on ImageNet-21k (pretrain_type='in21k-ft-in1k') + + Args: + pretrained (bool): Load pre-trained weights. + pretrain_type (str): Type of pre-training. Default is 'in21k'. Other options are 'in21k_old' and 'in1k'. + **kwargs: Additional arguments to pass to the model. + """ + assert pretrain_type in ['in21k', 'in21k_old', 'in21k-ft-in1k'], f"Invalid pretrain_type: {pretrain_type}" + if not pretrained: + print("WARNING: creating a ViT without pre-trained weights. This is not recommended.") + + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + if kwargs is None: + kwargs = {} + + if pretrain_type == 'in21k_old': + model = create_vision_transformer('vit_base_patch16_224_in21k_fn_in1k_old', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + elif pretrain_type == 'in21k': + model = create_vision_transformer('vit_base_patch16_224.augreg_in21k', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + else: + model = create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model diff --git a/datasets/deprecated/old_mnist_360.py b/datasets/deprecated/old_mnist_360.py index ed25a8f7..2620086f 100644 --- a/datasets/deprecated/old_mnist_360.py +++ b/datasets/deprecated/old_mnist_360.py @@ -104,7 +104,7 @@ def init_train_loaders(self) -> None: tmp_train_dataset.transform = transforms.Compose( [train_rotation, transforms.ToTensor()]) self.train_loaders[-1].append(create_seeded_dataloader(self.args, - tmp_train_dataset, batch_size=1, shuffle=True, num_workers=0)) + tmp_train_dataset, batch_size=1, shuffle=True, num_workers=0)) self.remaining_training_items[-1].append( tmp_train_dataset.data.shape[0]) @@ -127,7 +127,7 @@ def init_test_loaders(self) -> None: tmp_test_dataset.transform = transforms.Compose( [test_rotation, transforms.ToTensor()]) self.test_loaders.append(create_seeded_dataloader(self.args, tmp_test_dataset, - batch_size=self.args.batch_size, shuffle=True, num_workers=0)) + batch_size=self.args.batch_size, shuffle=True, num_workers=0)) def get_train_data(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -158,7 +158,7 @@ def get_train_data(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x_train.append(i_x_train) y_train.append(i_y_train) x_train_naug.append(i_x_train_naug) - x_train, y_train, x_train_naug = torch.cat(x_train),\ + x_train, y_train, x_train_naug = torch.cat(x_train), \ torch.cat(y_train), torch.cat(x_train_naug) self.active_remaining_training_items[0] -= batch_size_0 @@ -206,10 +206,5 @@ def get_transform(): def get_denormalization_transform(): return None - @staticmethod - def get_batch_size() -> int: - return 16 - - @staticmethod - def get_minibatch_size() -> int: + def get_batch_size(self) -> int: return 16 diff --git a/datasets/mnist_360.py b/datasets/mnist_360.py index d97a2607..0f038f06 100644 --- a/datasets/mnist_360.py +++ b/datasets/mnist_360.py @@ -19,6 +19,7 @@ from datasets.utils.gcl_dataset import GCLDataset from datasets.utils.validation import get_train_val from utils.conf import base_path, create_seeded_dataloader +from datasets.utils import set_default_from_args def custom_collate_unbatch(batch) -> List[torch.Tensor]: @@ -138,7 +139,7 @@ def init_test_loaders(self) -> None: train_dataset = MyMNIST(base_path() + 'MNIST', train=True, download=True) _, test_dataset = get_train_val( - train_dataset, test_transform, 'mnist-360', val_perc=self.args.validation / 100) + train_dataset, test_transform, 'mnist-360', val_perc=self.args.validation) else: test_dataset = MNIST(base_path() + 'MNIST', train=False, download=True) @@ -316,12 +317,12 @@ def get_normalization_transform(): def get_denormalization_transform(): return None - @staticmethod - def get_batch_size() -> int: + @set_default_from_args('batch_size') + def get_batch_size(self) -> int: return 16 - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 1 diff --git a/datasets/perm_mnist.py b/datasets/perm_mnist.py index 194cd12d..2187d5ba 100644 --- a/datasets/perm_mnist.py +++ b/datasets/perm_mnist.py @@ -16,6 +16,8 @@ from datasets.transforms.permutation import Permutation from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders from utils.conf import base_path +from datasets.utils import set_default_from_args + class MyMNIST(MNIST): """ @@ -103,10 +105,10 @@ def get_denormalization_transform(): def get_loss(): return F.cross_entropy - @staticmethod - def get_batch_size() -> int: + @set_default_from_args('batch_size') + def get_batch_size(self) -> int: return 128 - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 1 diff --git a/datasets/rot_mnist.py b/datasets/rot_mnist.py index 3e00f4f1..a346a3bc 100644 --- a/datasets/rot_mnist.py +++ b/datasets/rot_mnist.py @@ -11,6 +11,7 @@ from datasets.transforms.rotation import Rotation from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders from utils.conf import base_path +from datasets.utils import set_default_from_args class RotatedMNIST(ContinualDataset): @@ -64,10 +65,10 @@ def get_loss(): def get_denormalization_transform(): return None - @staticmethod - def get_batch_size() -> int: + @set_default_from_args('batch_size') + def get_batch_size(self) -> int: return 128 - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 1 diff --git a/datasets/seq_cifar10.py b/datasets/seq_cifar10.py index ce1c9d43..5705cd92 100644 --- a/datasets/seq_cifar10.py +++ b/datasets/seq_cifar10.py @@ -11,11 +11,12 @@ from PIL import Image from torchvision.datasets import CIFAR10 -from backbone.ResNet18 import resnet18 +from backbone.ResNetBlock import resnet18 from datasets.seq_tinyimagenet import base_path from datasets.transforms.denormalization import DeNormalize from datasets.utils.continual_dataset import (ContinualDataset, store_masked_loaders) +from datasets.utils import set_default_from_args class TCIFAR10(CIFAR10): @@ -96,17 +97,16 @@ class SequentialCIFAR10(ContinualDataset): transforms.ToTensor(), transforms.Normalize(MEAN, STD)]) + TEST_TRANSFORM = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD)]) + def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: """Class method that returns the train and test loaders.""" transform = self.TRANSFORM - test_transform = transforms.Compose( - [transforms.ToTensor(), self.get_normalization_transform()]) - train_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True, download=True, transform=transform) test_dataset = TCIFAR10(base_path() + 'CIFAR10', train=False, - download=True, transform=test_transform) + download=True, transform=self.TEST_TRANSFORM) train, test = store_masked_loaders(train_dataset, test_dataset, self) return train, test @@ -136,10 +136,10 @@ def get_denormalization_transform(): transform = DeNormalize(SequentialCIFAR10.MEAN, SequentialCIFAR10.STD) return transform - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 50 - @staticmethod - def get_batch_size(): + @set_default_from_args('batch_size') + def get_batch_size(self): return 32 diff --git a/datasets/seq_cifar100.py b/datasets/seq_cifar100.py index 45d82a75..8f71959b 100644 --- a/datasets/seq_cifar100.py +++ b/datasets/seq_cifar100.py @@ -12,12 +12,12 @@ from PIL import Image from torchvision.datasets import CIFAR100 -from backbone.ResNet18 import resnet18 +from backbone.ResNetBlock import resnet18 from datasets.transforms.denormalization import DeNormalize from datasets.utils.continual_dataset import (ContinualDataset, store_masked_loaders) -# from models.utils.continual_model import ContinualModel from utils.conf import base_path +from datasets.utils import set_default_from_args class TCIFAR100(CIFAR100): @@ -142,12 +142,12 @@ def get_denormalization_transform(): transform = DeNormalize(SequentialCIFAR100.MEAN, SequentialCIFAR100.STD) return transform - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 50 - @staticmethod - def get_batch_size(): + @set_default_from_args('batch_size') + def get_batch_size(self): return 32 @staticmethod diff --git a/datasets/seq_cifar100_224.py b/datasets/seq_cifar100_224.py index e7d7d53c..b47758ac 100644 --- a/datasets/seq_cifar100_224.py +++ b/datasets/seq_cifar100_224.py @@ -5,13 +5,14 @@ import torch import torch.nn.functional as F import torchvision.transforms as transforms -from timm import create_model +from backbone.vit import vit_base_patch16_224_prompt_prototype from datasets.seq_cifar100 import TCIFAR100, MyCIFAR100 from datasets.transforms.denormalization import DeNormalize from datasets.utils.continual_dataset import (ContinualDataset, store_masked_loaders) from utils.conf import base_path +from datasets.utils import set_default_from_args class SequentialCIFAR100224(ContinualDataset): @@ -70,12 +71,7 @@ def get_transform(): @staticmethod def get_backbone(hookme=False): - model_name = 'vit_base_patch16_224' - return create_model( - model_name, - pretrained=True, - num_classes=SequentialCIFAR100224.N_CLASSES - ) + return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialCIFAR100224.N_CLASSES_PER_TASK * SequentialCIFAR100224.N_TASKS) @staticmethod def get_loss(): @@ -91,10 +87,10 @@ def get_denormalization_transform(): transform = DeNormalize(SequentialCIFAR100224.MEAN, SequentialCIFAR100224.STD) return transform - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 5 - @staticmethod - def get_batch_size(): + @set_default_from_args('batch_size') + def get_batch_size(self): return 128 diff --git a/datasets/seq_cifar100_224_rs.py b/datasets/seq_cifar100_224_rs.py index 3dc78dc2..366ba966 100644 --- a/datasets/seq_cifar100_224_rs.py +++ b/datasets/seq_cifar100_224_rs.py @@ -4,12 +4,13 @@ import torch.nn.functional as F import torchvision.transforms as transforms -from backbone.ResNet50 import resnet50 +from backbone.ResNetBottleneck import resnet50 from datasets.seq_cifar100 import TCIFAR100, MyCIFAR100 from datasets.transforms.denormalization import DeNormalize from datasets.utils.continual_dataset import (ContinualDataset, store_masked_loaders) from utils.conf import base_path +from datasets.utils import set_default_from_args class SequentialCIFAR100224RS(ContinualDataset): @@ -85,10 +86,10 @@ def get_denormalization_transform(): transform = DeNormalize(SequentialCIFAR100224RS.MEAN, SequentialCIFAR100224RS.STD) return transform - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 50 - @staticmethod - def get_batch_size(): + @set_default_from_args('batch_size') + def get_batch_size(self): return 32 diff --git a/datasets/seq_cifar10_224.py b/datasets/seq_cifar10_224.py new file mode 100644 index 00000000..51748e8f --- /dev/null +++ b/datasets/seq_cifar10_224.py @@ -0,0 +1,95 @@ +# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. +# All rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms + +from backbone.vit import vit_base_patch16_224_prompt_prototype +from datasets.seq_cifar10 import TCIFAR10, MyCIFAR10 +from datasets.seq_tinyimagenet import base_path +from datasets.transforms.denormalization import DeNormalize +from datasets.utils.continual_dataset import (ContinualDataset, + store_masked_loaders) +from datasets.utils import set_default_from_args + + +class SequentialCIFAR10224(ContinualDataset): + """Sequential CIFAR10 Dataset. The images are resized to 224x224. + Version with ViT backbone. + + Args: + NAME (str): name of the dataset. + SETTING (str): setting of the dataset. + N_CLASSES_PER_TASK (int): number of classes per task. + N_TASKS (int): number of tasks. + N_CLASSES (int): number of classes. + SIZE (tuple): size of the images. + MEAN (tuple): mean of the dataset. + STD (tuple): standard deviation of the dataset. + TRANSFORM (torchvision.transforms): transformations to apply to the dataset. + """ + + NAME = 'seq-cifar10-224' + SETTING = 'class-il' + N_CLASSES_PER_TASK = 2 + N_TASKS = 5 + N_CLASSES = N_CLASSES_PER_TASK * N_TASKS + MEAN, STD = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615) + SIZE = (224, 224) + TRANSFORM = transforms.Compose( + [transforms.Resize(224), + transforms.RandomCrop(224, padding=28), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(MEAN, STD)]) + + TEST_TRANSFORM = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(MEAN, STD)]) + + def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: + """Class method that returns the train and test loaders.""" + transform = self.TRANSFORM + + train_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True, + download=True, transform=transform) + test_dataset = TCIFAR10(base_path() + 'CIFAR10', train=False, + download=True, transform=self.TEST_TRANSFORM) + + train, test = store_masked_loaders(train_dataset, test_dataset, self) + return train, test + + @staticmethod + def get_transform(): + transform = transforms.Compose( + [transforms.ToPILImage(), SequentialCIFAR10224.TRANSFORM]) + return transform + + @staticmethod + def get_backbone(): + return vit_base_patch16_224_prompt_prototype(pretrained=True, num_classes=SequentialCIFAR10224.N_CLASSES_PER_TASK * SequentialCIFAR10224.N_TASKS) + + @staticmethod + def get_loss(): + return F.cross_entropy + + @staticmethod + def get_normalization_transform(): + transform = transforms.Normalize(SequentialCIFAR10224.MEAN, SequentialCIFAR10224.STD) + return transform + + @staticmethod + def get_denormalization_transform(): + transform = DeNormalize(SequentialCIFAR10224.MEAN, SequentialCIFAR10224.STD) + return transform + + @set_default_from_args('n_epochs') + def get_epochs(self): + return 50 + + @set_default_from_args('batch_size') + def get_batch_size(self): + return 32 diff --git a/datasets/seq_cifar10_224_rs.py b/datasets/seq_cifar10_224_rs.py new file mode 100644 index 00000000..77c2e817 --- /dev/null +++ b/datasets/seq_cifar10_224_rs.py @@ -0,0 +1,96 @@ +# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara. +# All rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms + +from backbone.ResNetBottleneck import resnet50 +from datasets.seq_cifar10 import TCIFAR10, MyCIFAR10 +from datasets.seq_tinyimagenet import base_path +from datasets.transforms.denormalization import DeNormalize +from datasets.utils.continual_dataset import (ContinualDataset, + store_masked_loaders) +from datasets.utils import set_default_from_args + + +class SequentialCIFAR10224RS(ContinualDataset): + """Sequential CIFAR10 Dataset. The images are resized to 224x224. + Version with ResNet18 backbone. + + Args: + NAME (str): name of the dataset. + SETTING (str): setting of the dataset. + N_CLASSES_PER_TASK (int): number of classes per task. + N_TASKS (int): number of tasks. + N_CLASSES (int): number of classes. + SIZE (tuple): size of the images. + MEAN (tuple): mean of the dataset. + STD (tuple): standard deviation of the dataset. + TRANSFORM (torchvision.transforms): transformations to apply to the dataset. + """ + + NAME = 'seq-cifar10-224-rs' + SETTING = 'class-il' + N_CLASSES_PER_TASK = 2 + N_TASKS = 5 + N_CLASSES = N_CLASSES_PER_TASK * N_TASKS + MEAN, STD = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615) + SIZE = (224, 224) + TRANSFORM = transforms.Compose( + [transforms.Resize(224), + transforms.RandomCrop(224, padding=28), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(MEAN, STD)]) + + TEST_TRANSFORM = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(MEAN, STD)]) + + def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: + """Class method that returns the train and test loaders.""" + transform = self.TRANSFORM + + train_dataset = MyCIFAR10(base_path() + 'CIFAR10', train=True, + download=True, transform=transform) + test_dataset = TCIFAR10(base_path() + 'CIFAR10', train=False, + download=True, transform=self.TEST_TRANSFORM) + + train, test = store_masked_loaders(train_dataset, test_dataset, self) + return train, test + + @staticmethod + def get_transform(): + transform = transforms.Compose( + [transforms.ToPILImage(), SequentialCIFAR10224RS.TRANSFORM]) + return transform + + @staticmethod + def get_backbone(): + return resnet50(SequentialCIFAR10224RS.N_CLASSES_PER_TASK + * SequentialCIFAR10224RS.N_TASKS) + + @staticmethod + def get_loss(): + return F.cross_entropy + + @staticmethod + def get_normalization_transform(): + transform = transforms.Normalize(SequentialCIFAR10224RS.MEAN, SequentialCIFAR10224RS.STD) + return transform + + @staticmethod + def get_denormalization_transform(): + transform = DeNormalize(SequentialCIFAR10224RS.MEAN, SequentialCIFAR10224RS.STD) + return transform + + @set_default_from_args('n_epochs') + def get_epochs(self): + return 50 + + @set_default_from_args('batch_size') + def get_batch_size(self): + return 32 diff --git a/datasets/seq_cub200.py b/datasets/seq_cub200.py index 57d6d29c..07ba9ceb 100644 --- a/datasets/seq_cub200.py +++ b/datasets/seq_cub200.py @@ -9,12 +9,13 @@ from torch.utils.data.dataset import Dataset -from backbone.ResNet50 import resnet50 +from backbone.ResNetBottleneck import resnet50 from datasets.transforms.denormalization import DeNormalize from datasets.utils.continual_dataset import (ContinualDataset, store_masked_loaders) from utils import smart_joint from utils.conf import base_path +from datasets.utils import set_default_from_args class MyCUB200(Dataset): @@ -164,7 +165,7 @@ def get_data_loaders(self) -> Tuple[torch.utils.data.DataLoader, torch.utils.dat train_dataset = MyCUB200(base_path() + 'CUB200', train=True, download=True, transform=transform) test_dataset = CUB200(base_path() + 'CUB200', train=False, - download=True, transform=test_transform) + download=True, transform=test_transform) train, test = store_masked_loaders( train_dataset, test_dataset, self) @@ -197,10 +198,10 @@ def get_denormalization_transform(): transform = DeNormalize(SequentialCUB200.MEAN, SequentialCUB200.STD) return transform - @staticmethod - def get_batch_size(): + @set_default_from_args('batch_size') + def get_batch_size(self): return 16 - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 30 diff --git a/datasets/seq_imagenet_r.py b/datasets/seq_imagenet_r.py index f0772b89..21765f2a 100644 --- a/datasets/seq_imagenet_r.py +++ b/datasets/seq_imagenet_r.py @@ -6,7 +6,6 @@ import numpy as np from utils.conf import base_path from PIL import Image -from datasets.utils.validation import get_train_val from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders from typing import Tuple from datasets.transforms.denormalization import DeNormalize @@ -16,6 +15,7 @@ import pickle from torchvision.transforms.functional import InterpolationMode from utils.prompt_templates import templates +from datasets.utils import set_default_from_args class MyImagenetR(Dataset): @@ -138,12 +138,9 @@ def get_data_loaders(self): train_dataset = MyImagenetR(base_path() + 'imagenet-r/', train=True, download=True, transform=transform) - if self.args.validation: - train_dataset, test_dataset = get_train_val(train_dataset, - test_transform, self.NAME) - else: - test_dataset = MyImagenetR(base_path() + 'imagenet-r/', train=False, - download=True, transform=test_transform) + + test_dataset = MyImagenetR(base_path() + 'imagenet-r/', train=False, + download=True, transform=test_transform) train, test = store_masked_loaders(train_dataset, test_dataset, self) return train, test @@ -189,12 +186,12 @@ def get_denormalization_transform(): (1, 1, 1)) return transform - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 50 - @staticmethod - def get_batch_size(): + @set_default_from_args('batch_size') + def get_batch_size(self): return 32 @staticmethod diff --git a/datasets/seq_mnist.py b/datasets/seq_mnist.py index 4348f604..9bb5ed2c 100644 --- a/datasets/seq_mnist.py +++ b/datasets/seq_mnist.py @@ -15,6 +15,7 @@ from datasets.utils.continual_dataset import (ContinualDataset, store_masked_loaders) from utils.conf import base_path +from datasets.utils import set_default_from_args class MyMNIST(MNIST): @@ -108,10 +109,10 @@ def get_normalization_transform(): def get_denormalization_transform(): return None - @staticmethod - def get_batch_size(): + @set_default_from_args('batch_size') + def get_batch_size(self): return 64 - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 1 diff --git a/datasets/seq_tinyimagenet.py b/datasets/seq_tinyimagenet.py index 9071b9f5..ac36d554 100644 --- a/datasets/seq_tinyimagenet.py +++ b/datasets/seq_tinyimagenet.py @@ -14,12 +14,13 @@ from PIL import Image from torch.utils.data import Dataset -from backbone.ResNet18 import resnet18 +from backbone.ResNetBlock import resnet18 from datasets.transforms.denormalization import DeNormalize from datasets.utils.continual_dataset import (ContinualDataset, store_masked_loaders) from utils import smart_joint from utils.conf import base_path +from datasets.utils import set_default_from_args class TinyImagenet(Dataset): @@ -177,10 +178,10 @@ def get_denormalization_transform(): transform = DeNormalize(SequentialTinyImagenet.MEAN, SequentialTinyImagenet.STD) return transform - @staticmethod - def get_epochs(): + @set_default_from_args('n_epochs') + def get_epochs(self): return 50 - @staticmethod - def get_batch_size(): + @set_default_from_args('batch_size') + def get_batch_size(self): return 32 diff --git a/datasets/utils/__init__.py b/datasets/utils/__init__.py index 920aa905..dbf74218 100644 --- a/datasets/utils/__init__.py +++ b/datasets/utils/__init__.py @@ -1,3 +1,51 @@ """ This package contains utility functions used by all datasets, including the base dataset class (ContinualDataset). """ + +import functools +import inspect + +# Default arguments defined by the datasets +DEFAULT_ARGS = {} + + +def is_static_call(*args) -> bool: + """ + Check if the function is called without any arguments. + + Returns: + bool: True if the function is called without any arguments, False otherwise. + """ + return len(args) == 0 + + +def set_default_from_args(arg_name: str): + """ + Decorator to define the default value of an argument of a given dataset. + + Args: + arg_name (str): The name of the argument to set the default value for. + + Returns: + function: The decorator to set the default value of the argument. + """ + + global DEFAULT_ARGS + caller = inspect.currentframe().f_back + caller_name = caller.f_locals['NAME'] + if caller_name not in DEFAULT_ARGS: + DEFAULT_ARGS[caller_name] = {} + + def decorator_set_default_from_args(func): + DEFAULT_ARGS[caller_name][arg_name] = func(None) + + @functools.wraps(func) + def wrapper(*args): + + if is_static_call(*args): + # if no arguments are passed, return the function + return func(None) + + return func(*args) + return wrapper + return decorator_set_default_from_args diff --git a/datasets/utils/continual_dataset.py b/datasets/utils/continual_dataset.py index f78f2098..0ffc3915 100644 --- a/datasets/utils/continual_dataset.py +++ b/datasets/utils/continual_dataset.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from argparse import Namespace +import sys from typing import Tuple import torch @@ -12,10 +13,12 @@ import torch.optim.lr_scheduler as scheds from torch.utils.data import DataLoader, Dataset +from datasets.utils.validation import get_validation_indexes from utils.conf import create_seeded_dataloader +from datasets.utils import DEFAULT_ARGS -class ContinualDataset: +class ContinualDataset(object): """ A base class for defining continual learning datasets. @@ -78,6 +81,29 @@ def __init__(self, args: Namespace) -> None: if not all((self.NAME, self.SETTING, self.N_CLASSES_PER_TASK, self.N_TASKS, self.SIZE, self.N_CLASSES)): raise NotImplementedError('The dataset must be initialized with all the required fields.') + def update_default_args(self): + """ + Updates the default arguments with the ones specified in the dataset class. + Default arguments are defined in the DEFAULT_ARGS dictionary and set by the 'set_default_from_args' decorator. + + Returns: + Namespace: the updated arguments + """ + + if self.args.dataset not in DEFAULT_ARGS: # no default args for this dataset + return self.args + + for k, v in DEFAULT_ARGS[self.args.dataset].items(): + assert hasattr(self.args, k), f'Argument {k} set by the `set_default_from_args` decorator is not present in the arguments.' + + if getattr(self.args, k) is None: + setattr(self.args, k, v) + else: + if getattr(self.args, k) != v: + print('Warning: {} set to {} instead of {}.'.format(k, getattr(self.args, k), v), file=sys.stderr) + + return self.args + def get_offsets(self, task_idx: int = None): """ Compute the start and end class index for the current task. @@ -150,19 +176,21 @@ def get_scheduler(model, args: Namespace) -> torch.optim.lr_scheduler._LRSchedul return sched return None + def get_iters(self): + """Returns the number of iterations to be used for the current dataset.""" + raise NotImplementedError('The dataset does not implement the method `get_iters` to set the default number of iterations.') + def get_epochs(self): """Returns the number of epochs to be used for the current dataset.""" - raise NotImplementedError + raise NotImplementedError('The dataset does not implement the method `get_epochs` to set the default number of epochs.') - @staticmethod - def get_batch_size(): + def get_batch_size(self): """Returns the batch size to be used for the current dataset.""" - raise NotImplementedError + raise NotImplementedError('The dataset does not implement the method `get_batch_size` to set the default batch size.') - @classmethod - def get_minibatch_size(cls): + def get_minibatch_size(self): """Returns the minibatch size to be used for the current dataset.""" - return cls.get_batch_size() + return self.get_batch_size() def _get_mask_unlabeled(train_dataset, setting: ContinualDataset): @@ -219,27 +247,32 @@ def store_masked_loaders(train_dataset: Dataset, test_dataset: Dataset, test_dataset.targets = setting.args.class_order[test_dataset.targets] if setting.args.validation: - n_samples = len(train_dataset) - n_samples_val = torch.div(n_samples, setting.args.validation, rounding_mode='floor').item() + train_idxs, val_idxs = get_validation_indexes(setting.args.validation, train_dataset, setting.args.seed) - train_idxs = torch.randperm(n_samples, generator=torch.Generator().manual_seed(setting._c_seed)).numpy() - val_idxs = train_idxs[:n_samples_val] - train_idxs = train_idxs[n_samples_val:] + test_dataset.data = train_dataset.data[val_idxs] + test_dataset.targets = train_dataset.targets[val_idxs] - train_dataset.data, test_dataset.data = train_dataset.data[train_idxs], train_dataset.data[val_idxs] - train_dataset.targets, test_dataset.targets = train_dataset.targets[train_idxs], train_dataset.targets[val_idxs] + train_dataset.data = train_dataset.data[train_idxs] + train_dataset.targets = train_dataset.targets[train_idxs] if setting.SETTING == 'class-il' or setting.SETTING == 'task-il': - train_mask = np.logical_and(np.array(train_dataset.targets) >= setting.i, - np.array(train_dataset.targets) < setting.i + setting.N_CLASSES_PER_TASK) - test_mask = np.logical_and(np.array(test_dataset.targets) >= setting.i, - np.array(test_dataset.targets) < setting.i + setting.N_CLASSES_PER_TASK) + train_mask = np.logical_and(train_dataset.targets >= setting.i, + train_dataset.targets < setting.i + setting.N_CLASSES_PER_TASK) + + if setting.args.validation_mode == 'current': + test_mask = np.logical_and(test_dataset.targets >= setting.i, + test_dataset.targets < setting.i + setting.N_CLASSES_PER_TASK) + elif setting.args.validation_mode == 'complete': + test_mask = np.logical_and(test_dataset.targets >= 0, + test_dataset.targets < setting.i + setting.N_CLASSES_PER_TASK) + else: + raise ValueError('Unknown validation mode: {}'.format(setting.args.validation_mode)) - train_dataset.data = train_dataset.data[train_mask] test_dataset.data = test_dataset.data[test_mask] + test_dataset.targets = test_dataset.targets[test_mask] + train_dataset.data = train_dataset.data[train_mask] train_dataset.targets = train_dataset.targets[train_mask] - test_dataset.targets = test_dataset.targets[test_mask] train_dataset, test_dataset = _prepare_data_loaders(train_dataset, test_dataset, setting) @@ -249,7 +282,7 @@ def store_masked_loaders(train_dataset: Dataset, test_dataset: Dataset, batch_size=setting.args.batch_size, shuffle=False) setting.test_loaders.append(test_loader) setting.train_loader = train_loader - + if setting.SETTING == 'task-il' or setting.SETTING == 'class-il': setting.i += setting.N_CLASSES_PER_TASK setting.c_task += 1 diff --git a/datasets/utils/gcl_dataset.py b/datasets/utils/gcl_dataset.py index 5c64c113..d4214764 100644 --- a/datasets/utils/gcl_dataset.py +++ b/datasets/utils/gcl_dataset.py @@ -35,8 +35,7 @@ def __init__(self, args: Namespace) -> None: if not all((self.NAME, self.SETTING, self.SIZE)): raise NotImplementedError('The dataset must be initialized with all the required fields.') - @staticmethod - def get_epochs(): + def get_epochs(self): """ A GCLDataset is not compatible with multiple epochs. """ diff --git a/datasets/utils/validation.py b/datasets/utils/validation.py index be79192d..3f34ba3c 100644 --- a/datasets/utils/validation.py +++ b/datasets/utils/validation.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import os -from typing import Optional +from typing import Optional, Tuple import numpy as np import torch @@ -48,6 +48,37 @@ def __getitem__(self, index): return img, target +def get_validation_indexes(validation_size: float, dataset: Dataset, seed=None) -> Tuple[Dataset, Dataset]: + """ + Returns the indexes of train and validation datasets from the given dataset, according to the validation size. + + Args: + validation_size (float): percentage of samples for each class to be used for validation (between 0 and 100) + dataset (Dataset): the dataset to split + seed (int): the seed for the random generator. If None, the seed is set to 0 + + Returns: + tuple: the train and validation dataset indexes + """ + seed = 0 if seed is None else seed + if validation_size < 1: + validation_size = round(validation_size * 100, 2) + + cls_ids, samples_per_class = np.unique(dataset.targets, return_counts=True) + n_samples_val_per_class = np.ceil(samples_per_class * (validation_size / 100)).astype(int) + + all_idxs = np.arange(len(dataset.targets)) + val_idxs, train_idxs = [], [] + for cls_id, n_samples, n_samples_val in zip(cls_ids, samples_per_class, n_samples_val_per_class): + cls_idxs = all_idxs[dataset.targets == cls_id] + idxs = torch.randperm(n_samples, generator=torch.Generator().manual_seed(seed)).numpy() + val_idxs.append(cls_idxs[idxs[:n_samples_val]]) + train_idxs.append(cls_idxs[idxs[n_samples_val:]]) + + train_idxs = np.concatenate(train_idxs) + val_idxs = np.concatenate(val_idxs) + + return train_idxs, val_idxs def get_train_val(train: Dataset, test_transform: nn.Module, dataset: str, val_perc: float = 0.1): @@ -72,12 +103,13 @@ def get_train_val(train: Dataset, test_transform: nn.Module, else: perm = torch.randperm(dataset_length) torch.save(perm, directory + file_name) - train.data = train.data[perm] - train.targets = np.array(train.targets)[perm] - test_dataset = ValidationDataset(train.data[:int(val_perc * dataset_length)], - train.targets[:int(val_perc * dataset_length)], + + train_idxs, val_idxs = get_validation_indexes(val_perc, train) + + test_dataset = ValidationDataset(train.data[val_idxs], + train.targets[val_idxs], transform=test_transform) - train.data = train.data[int(val_perc * dataset_length):] - train.targets = train.targets[int(val_perc * dataset_length):] + train.data = train.data[train_idxs] + train.targets = train.targets[train_idxs] return train, test_dataset diff --git a/docs/datasets/index.rst b/docs/datasets/index.rst index ac1db709..c34db1c8 100644 --- a/docs/datasets/index.rst +++ b/docs/datasets/index.rst @@ -73,6 +73,19 @@ and are defined in the **SETTING** attribute of each dataset. The following sett Mammoth datasets support the **joint** setting, which is a special case of the `class-il` setting where all the classes are available at each task. This is useful to compare the performance of a method on what is usually considered the *upper bound* for the `class-il` setting. To run an experiment on the **joint** setting, simply set the ``--joint`` to ``1``. This will automatically set the **N_CLASSES_PER_TASK** attribute to the total number of classes in the dataset and the **TASKS** attribute to ``1``. +Default arguments and command line +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Besides **get_epochs** and **get_batch_size**, datasets can define default arguments that are used to set the default values for the command line arguments. +This is done with the **set_default_from_args** decorator, which takes the name of the command line argument as input. For example, the following code sets the default value for the `--label_perc` argument: + +.. code-block:: python + + @set_default_from_args('--label_perc') + def get_label_perc(self): + return 0.5 + + Steps to create a new dataset: ------------------------------ diff --git a/docs/getting_started/checkpoints.rst b/docs/getting_started/checkpoints.rst index 7dfd6cb9..dc38b9a6 100644 --- a/docs/getting_started/checkpoints.rst +++ b/docs/getting_started/checkpoints.rst @@ -3,6 +3,16 @@ Load and save checkpoints Loading and saving checkpoints is handeled automatically in :ref:`module-training` by supplying the ``--savecheck`` and ``--loadcheck`` arguments. +For example, to save a checkpoint after training, simply run the following command: +.. code-block:: python + + python utils/main.py --savecheck=1 --model=sgd --dataset=seq-cifar10 --lr=0.1 + +This will save the checkpoint in the ``checkpoints`` folder. To load the checkpoint, simply run the following command: +.. code-block:: python + + python utils/main.py --loadcheck=.pt --model=sgd --dataset=seq-cifar10 --lr=0.1 + .. rubric:: Checkpoint save format Mammoth saves checkpoints in the ``checkpoints`` folder, with a separate checkpoint file for each task. The checkpoint file follows the format: ``[]_____.pt``. @@ -33,10 +43,10 @@ Inside the checkpoint file, the following information is saved: - ``results``: all the metrics mesured up to the current task and the state of the logger. This information is necessary in order to continue training from the last checkpoint. -.. rubric:: Checkpoint loading +.. rubric:: Additional info on checkpoint loading -Mammoth supports loading checkpoint both from the local machine and from a remote machine using the ``--loadcheck`` argument. To load a checkpoint from a remote machine, simply supply the ``--loadcheck`` with the URL of the checkpoint file. +Mammoth supports loading checkpoint both from the local machine **and from a remote machine** using the ``--loadcheck`` argument. To load a checkpoint from a remote machine, simply supply the ``--loadcheck`` with the URL of the checkpoint file. Checkpoints can be loaded either following the mammoth format (defined above) or from a simple ``.pt`` file. In the latter case, the checkpoint file should contain all the parameters of the *backbone* of the model. The other parameters (optimizer, scheduler, etc.) will be initialized from scratch. -The loading functions are available in :ref:`module-checkpoints` and should take care of loading all the parameters regardless of the presence of module parallelism (see :ref:`module-distributed-training`). \ No newline at end of file +The loading functions are available in :ref:`module-checkpoints` and should take care of loading all the parameters regardless of the presence of module parallelism (see :ref:`module-fast-training`). \ No newline at end of file diff --git a/docs/getting_started/distributed_training.rst b/docs/getting_started/distributed_training.rst deleted file mode 100644 index 4d749eda..00000000 --- a/docs/getting_started/distributed_training.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. _module-distributed-training: - -Distributed training -==================== - -Mammoth supports distributed training via `DataParallel `_. To use it, simply pass the `--distributed=dp` argument to ``utils/main.py``. This will automatically use all available GPUs on the machine using the **make_dp** function in :ref:`module-distributed`. - -DataParallel training **splits the batch** across GPUs and performs the forward and backward passes on each GPU. The gradients are then **averaged** across GPUs and the model parameters are updated. This is the simplest form of distributed training supported by PyTorch and is the only one supported by Mammoth as of now. - -.. important:: - As of now, Mammoth only supports DataParallel training. This is due to the difficulty of synchronizing the memory buffer across multiple GPUs after each batch. However, experimental support for `DistributedDataParallel `_ training in a `slurm `_ cluster is available in the :ref:`module-distributed` module via the **make_ddp** function. \ No newline at end of file diff --git a/docs/getting_started/fast_training.rst b/docs/getting_started/fast_training.rst new file mode 100644 index 00000000..95d2f05b --- /dev/null +++ b/docs/getting_started/fast_training.rst @@ -0,0 +1,27 @@ +.. _module-fast-training: + +Fast training \& optimizations +============================== + +.. important:: + The optimizations described in this section require an NVIDIA GPU with the `Ampere architecture `_ (RTX 30xx series or newer) and the `CUDA Toolkit `_ installed. If you do not have an Ampere GPU, you can still use Mammoth without these optimizations. + +Mammoth provides a number of optimizations to speed up training. These are **disabled** by default (mainly to improve ease of debugging), but can be enabled by passing the `--code_optimization` (or `-O`) flag to ``utils/main.py``. The available optimizations are: + +* **0**: No optimization (default) +* **1**: Use the ``TF32`` data type for training IF IT IS AVAILABLE (*i.e.*, sets the `torch.set_float32_matmul_precision` to `high`). **This will fall back to FP32 if TF32 is not available**. +* **2**: Use the ``BF16`` data type for training (*i.e.*, sets the `torch.set_bf16_cvt_precision` to `medium`). **This will throw an error if the GPU does not support BF16**. +* **3**: Same as *2*, but also includes ``torch.compile``. This option has some caveats: + - It is only available on Linux (check `this issue `_ for updates). + - It does not work if the model *changes* during training. This includes increasing the number of classifiers, prompts, etc. + - It may not give a significant speedup for small models. + +Distributed training +==================== + +Mammoth supports distributed training via `DataParallel `_. To use it, simply pass the `--distributed=dp` argument to ``utils/main.py``. This will automatically use all available GPUs on the machine using the **make_dp** function in :ref:`module-distributed`. + +DataParallel training **splits the batch** across GPUs and performs the forward and backward passes on each GPU. The gradients are then **averaged** across GPUs and the model parameters are updated. This is the simplest form of distributed training supported by PyTorch and is the only one supported by Mammoth as of now. + +.. important:: + As of now, Mammoth only supports DataParallel training. This is due to the difficulty of synchronizing the memory buffer across multiple GPUs after each batch. However, experimental support for `DistributedDataParallel `_ training in a `slurm `_ cluster is available in the :ref:`module-distributed` module via the **make_ddp** function. \ No newline at end of file diff --git a/docs/getting_started/validation.rst b/docs/getting_started/validation.rst new file mode 100644 index 00000000..22f0bdaf --- /dev/null +++ b/docs/getting_started/validation.rst @@ -0,0 +1,50 @@ +.. _module-validation: + +Training, Validation, and Testing +================================= + +During each task, Mammoth trains on the current data until some stopping criterion is met. +Currently, Mammoth supports 3 types of stopping criteria, which can be chosen using the ``--fitting_mode`` command line argument. The three types are ``epochs``, ``iters``, and ``early_stopping``. The default is ``epochs``. + +.. rubric:: Criterion by epochs (``--fitting_modeepochs``) + +This is the default option, for which training stops after a fixed number of **epochs**. The number of epochs can be set using the ``--n_epochs`` command line argument. Note that most datasets indicate the default number of epochs via the `set_default_from_args` decorator (see :ref:`module-datasets` for more information). + +.. rubric:: Criterion by iterations (``--fitting_modeiters``) + +This option stops training after a fixed number of **iterations**. The number of iterations can be set using the ``--n_iters`` command line argument. In addition, a default value for each dataset can be set using the `set_default_from_args` decorator. For example, to set the default number of iterations to 1000 for a particular dataset you can use the following code, adding it to the dataset class definition: + +.. code-block:: python + + @set_default_from_args('n_iters') + def get_iters(self): + return 1000 + +.. rubric:: Early stopping (``--fitting_mode=early_stopping``) + +This option is the most flexible, as it allows training to continue until a certain stopping criterion is met. The criterion can be based either on the loss function or on the accuracy by setting the ``--early_stopping_metric`` command line argument (default is by ``loss``). The number of epochs to wait before stopping can be set using the ``--early_stopping_patience`` command line argument. The default value is 5. + +.. note:: + + The early stopping criterion is based on the chosen validation set (see next section). + +In addition, the early stopping criterion supports the following options: + +* ``early_stopping_freq`` (default is 1): the frequency at which the early stopping criterion is checked. For example, if ``early_stopping_freq`` is set to 5, the criterion is checked every 5 epochs. +* ``early_stopping_epsilon`` (default is 1e-6): the minimum improvement in the validation loss or accuracy that is considered significant. If the improvement is less than ``early_stopping_epsilon``, the training stops. + +Validation +---------- + +During training, Mammoth uses a validation set to monitor the performance of the model. By default, the validation set is **disabled**, meaning that performance is monitored on the **test** set. This is done in line with most CL literature, which uses the test set for validation as it is not trivial to define a validation set for CL tasks. In particular, two options may be possible: + +1. *The validation set includes data* **only of the current task**: this is the most straightforward option, but has the disantvantage of producing a higher degree of forgetting on past tasks as the model's objective ignores past data. +2. *The validation set includes data* **of all seen tasks**: this option should produce a more balanced result, but conflicts with the CL setting, as the model should not have access to data from past tasks. However, since most CL works only focus on maximizing the accuracy on all tasks after having seen all of them, this option is the most common in the literature. + +In Mammoth, the use of a validation set can be enabled by specifiying the percentage of the training set that is used as the validation set using the ``--validation`` command line argument. For example, to use 10% of the training set as the validation set, you can use the following command: + +.. code-block:: bash + + python main.py --validation 10 + +As for the choice of strategy to build the validation set, Mammoth supports both options described above using the ``--validation_mode`` command line argument. The default is ``current``, meaning that the validation set includes only data of the current task. If you want to use a validation set that includes data of all seen tasks, you can set ``--validation_mode`` to ``complete``. diff --git a/docs/index.rst b/docs/index.rst index 4c2ff9a5..dbdeee1a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,8 +22,9 @@ :caption: Getting started: getting_started/index.rst + getting_started/validation.rst getting_started/checkpoints.rst - getting_started/distributed_training.rst + getting_started/fast_training.rst getting_started/scripts.rst Parseval diff --git a/docs/models/index.rst b/docs/models/index.rst index ac228cba..6ba50f33 100644 --- a/docs/models/index.rst +++ b/docs/models/index.rst @@ -91,9 +91,9 @@ The base class **ContinualModel** provides a few properties that are automatical .. admonition:: Transforms and dataset-related Attributes - - **transform**: the transform applied to the input data. This attribute is automatically set during the initialization of the model and is defined by the chosen **dataset** (see :ref:`module-datasets` for more details). + - **transform**: the transform applied to the input data. This attribute is automatically set during the initialization of the model and is defined by the chosen **dataset** (see :ref:`module-datasets` for more details). In most cases, this is implemented as a `kornia `_ transform (translated from PIL thanks to `to_kornia_transform` in :ref:`module-kornia_utils`). However, if a transform is not supported by the **to_kornia_transform**, it is implemented as `PIL `_. - - **weak_transform**: this function is used to apply a new transform to a `torch.Tensor `_. In most cases, this is implemented as a `kornia `_ transform. However, if a transform is not supported by the **to_kornia_transform**, it is implemented as `PIL `_. + - **original_transform**: the original transform defined by the chosen **dataset**. This is implemented as a `PIL `_ transform (and not translated into `kornia` as the **transform**). - **normalization_transform**: the transform used to normalize the input data. As for the **weak_transform**, this is implemented as a `kornia `_ transform if possible, otherwise it is implemented as `PIL `_. diff --git a/docs/readme.rst b/docs/readme.rst index d74cf12e..8bad0be3 100644 --- a/docs/readme.rst +++ b/docs/readme.rst @@ -16,7 +16,9 @@ Idelly, all the code necessary to run the experiments is included *in the reposi With Mammoth, nothing is set in stone. You can easily add new models, datasets, training strategies, or functionalities. -**NEW**: Join our Discord Server for all your Mammoth-related questions → ![Discord Shield](https://discordapp.com/api/guilds/1164956257392799860/widget.png?style=shield) +**NEW**: Join our Discord Server for all your Mammoth-related questions! + +.. image:: https://discordapp.com/api/guilds/1164956257392799860/widget.png?style=shield .. list-table:: :widths: 15 15 15 15 15 15 @@ -161,7 +163,7 @@ Our Papers Other Awesome CL works using Mammoth ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. admonition:: +.. important:: **Get in touch if we missed your awesome work!** diff --git a/docs/utils/args.rst b/docs/utils/args.rst index f8411906..528dd912 100644 --- a/docs/utils/args.rst +++ b/docs/utils/args.rst @@ -5,243 +5,288 @@ Arguments .. rubric:: EXPERIMENT-RELATED ARGS +.. rubric:: Experiment arguments + +*Arguments used to define the experiment settings.* + **\-\-dataset** : *Help*: Which dataset to perform experiments on. - Default: None - - Choices: seq-tinyimg, seq-tinyimg-r, perm-mnist, seq-cifar10, seq-cifar100-224, seq-cub200, rot-mnist, seq-cifar100, seq-cifar100-224-rs, seq-mnist, mnist-360 - -**\-\-model** : + - Choices: mnist-360, perm-mnist, rot-mnist, seq-cifar10, seq-cifar100, seq-cifar100-224, seq-cifar100-224-rs, seq-cifar10-224, seq-cifar10-224-rs, seq-cub200, seq-imagenet-r, seq-mnist, seq-tinyimg, seq-tinyimg-r +**\-\-model** : *Help*: Model name. - Default: None - - Choices: agem, agem-r, ewc-on, derpp-lider, gdumb-lider, slca, dualprompt, si, bic, er-ace, fdr, gdumb, gem, gss, joint-gcl, lwf, mer, rpc, twf, ccic, der, derpp, er, hal, icarl, l2p, lucir, lwf-mc, sgd, xder, xder-ce, xder-rpc, pnn, er-ace-lider, icarl-lider, coda-prompt - + - Choices: agem, agem-r, bic, ccic, coda-prompt, der, derpp, derpp-lider, dualprompt, er, er-ace, er-ace-lider, ewc-on, fdr, gdumb, gdumb-lider, gem, gss, hal, icarl, icarl-lider, joint-gcl, l2p, lucir, lwf, lwf-mc, mer, pnn, rpc, sgd, si, slca, twf, xder, xder-ce, xder-rpc **\-\-lr** : *Help*: Learning rate. - Default: None - Choices: +**\-\-batch_size** : + *Help*: Batch size. -**\-\-optimizer** : - *Help*: Optimizer. - - - Default: sgd - - - Choices: sgd, adam, adamw + - Default: None -**\-\-optim_wd** : - *Help*: optimizer weight decay. + - Choices: +**\-\-label_perc** : + *Help*: Percentage in (0-1] of labeled examples per task. - - Default: 0.0 + - Default: 1 - Choices: +**\-\-joint** : + *Help*: Train model on Joint (single task)? -**\-\-optim_mom** : - *Help*: optimizer momentum. + - Default: 0 - - Default: 0.0 + - Choices: 0, 1 - - Choices: +.. rubric:: Validation and fitting arguments -**\-\-optim_nesterov** : - *Help*: optimizer nesterov momentum. +*Arguments used to define the validation strategy and the method used to fit the model.* - - Default: 0 +**\-\-validation** : + *Help*: Percentage of samples FOR EACH CLASS drawn from the training set to build the validation set. + + - Default: None - Choices: +**\-\-validation_mode** : + *Help*: Mode used for validation. Must be used in combination with `validation` argument. Possible values: - `current`: uses only the current task for validation (default). - `complete`: uses data from both current and past tasks for validation. -**\-\-lr_scheduler** : - *Help*: Learning rate scheduler. + - Default: current - - Default: None + - Choices: complete, current +**\-\-fitting_mode** : + *Help*: Strategy used for fitting the model. Possible values: - `epochs`: fits the model for a fixed number of epochs (default). NOTE: this option is controlled by the `n_epochs` argument. - `iters`: fits the model for a fixed number of iterations. NOTE: this option is controlled by the `n_iters` argument. - `early_stopping`: fits the model until early stopping criteria are met. This option requires a validation set (see `validation` argument). The early stopping criteria are: if the validation loss does not decrease for `early_stopping_patience` epochs, the training stops. - - Choices: + - Default: epochs -**\-\-lr_milestones** : - *Help*: Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`). + - Choices: epochs, iters, time, early_stopping +**\-\-early_stopping_patience** : + *Help*: Number of epochs to wait before stopping the training if the validation loss does not decrease. Used only if `fitting_mode=early_stopping`. - - Default: [] + - Default: 5 - Choices: +**\-\-early_stopping_metric** : + *Help*: Metric used for early stopping. Used only if `fitting_mode=early_stopping`. -**\-\-sched_multistep_lr_gamma** : - *Help*: Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`). + - Default: loss - - Default: 0.1 + - Choices: loss, accuracy +**\-\-early_stopping_freq** : + *Help*: Frequency of validation evaluation. Used only if `fitting_mode=early_stopping`. + + - Default: 1 - Choices: +**\-\-early_stopping_epsilon** : + *Help*: Minimum improvement required to consider a new best model. Used only if `fitting_mode=early_stopping`. + - Default: 1e-06 + + - Choices: **\-\-n_epochs** : - *Help*: Number of epochs. + *Help*: Number of epochs. Used only if `fitting_mode=epochs`. - Default: None - Choices: - -**\-\-batch_size** : - *Help*: Batch size. +**\-\-n_iters** : + *Help*: Number of iterations. Used only if `fitting_mode=iters`. - Default: None - Choices: -**\-\-distributed** : - *Help*: Enable distributed training? +.. rubric:: Optimizer and learning rate scheduler arguments - - Default: no +*Arguments used to define the optimizer and the learning rate scheduler.* - - Choices: no, dp, ddp - -**\-\-savecheck** : None - *Help*: Save checkpoint? - - - Default: False +**\-\-optimizer** : + *Help*: Optimizer. - - Choices: + - Default: sgd -**\-\-loadcheck** : - *Help*: Path of the checkpoint to load (.pt file for the specific task) + - Choices: sgd, adam, adamw +**\-\-optim_wd** : + *Help*: optimizer weight decay. - - Default: None + - Default: 0.0 - Choices: +**\-\-optim_mom** : + *Help*: optimizer momentum. -**\-\-ckpt_name** : - *Help*: (optional) checkpoint save name. - - - Default: None + - Default: 0.0 - Choices: +**\-\-optim_nesterov** : + *Help*: optimizer nesterov momentum. -**\-\-start_from** : - *Help*: Task to start from - - - Default: None + - Default: 0 - Choices: - -**\-\-stop_after** : - *Help*: Task limit +**\-\-lr_scheduler** : + *Help*: Learning rate scheduler. - Default: None - Choices: +**\-\-lr_milestones** : + *Help*: Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`). -**\-\-joint** : - *Help*: Train model on Joint (single task)? - - - Default: 0 - - - Choices: 0, 1 + - Default: [] -**\-\-label_perc** : - *Help*: Percentage in (0-1] of labeled examples per task. + - Choices: +**\-\-sched_multistep_lr_gamma** : + *Help*: Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`). - - Default: 1 + - Default: 0.1 - Choices: .. rubric:: MANAGEMENT ARGS +.. rubric:: Management arguments + +*Generic arguments to manage the experiment reproducibility, logging, debugging, etc.* + **\-\-seed** : - *Help*: The random seed. + *Help*: The random seed. If not provided, a random seed will be used. - Default: None - Choices: - **\-\-permute_classes** : - *Help*: Permute classes before splitting tasks (applies seed before permute if seed is present)? + *Help*: Permute classes before splitting into tasks? This applies the seed before permuting if the `seed` argument is present. - Default: 0 - Choices: 0, 1 - **\-\-base_path** : *Help*: The base path where to save datasets, logs, results. - Default: ./data/ - Choices: - **\-\-notes** : - *Help*: Notes for this run. + *Help*: Helper argument to include notes for this run. Example: distinguish between different versions of a model and allow separation of results - Default: None - Choices: +**\-\-eval_epochs** : + *Help*: Perform inference on validation every `eval_epochs` epochs. If not provided, the model is evaluated ONLY at the end of each task. + + - Default: None + - Choices: **\-\-non_verbose** : *Help*: Make progress bars non verbose - Default: 0 - Choices: 0, 1 - **\-\-disable_log** : *Help*: Disable logging? - Default: 0 - Choices: 0, 1 - **\-\-num_workers** : *Help*: Number of workers for the dataloaders (default=infer from number of cpus). - Default: None - Choices: +**\-\-enable_other_metrics** : + *Help*: Enable computing additional metrics: forward and backward transfer. -**\-\-validation** : - *Help*: Percentage of validation set drawn from the training set. + - Default: 0 - - Default: None + - Choices: 0, 1 +**\-\-debug_mode** : + *Help*: Run only a few training steps per epoch. This also disables logging on wandb. - - Choices: + - Default: 0 -**\-\-enable_other_metrics** : - *Help*: Enable computing additional metrics: forward and backward transfer. + - Choices: 0, 1 +**\-\-inference_only** : + *Help*: Perform inference only for each task (no training). - Default: 0 - Choices: 0, 1 +**\-\-code_optimization** : + *Help*: Optimization level for the code.0: no optimization.1: Use TF32, if available.2: Use BF16, if available.3: Use BF16 and `torch.compile`. BEWARE: torch.compile may break your code if you change the model after the first run! Use with caution. -**\-\-debug_mode** : - *Help*: Run only a few forward steps per epoch + - Default: 0 + + - Choices: 0, 1, 2, 3 +**\-\-distributed** : + *Help*: Enable distributed training? + + - Default: no + + - Choices: no, dp, ddp +**\-\-savecheck** : + *Help*: Save checkpoint? - Default: 0 - Choices: 0, 1 +**\-\-loadcheck** : + *Help*: Path of the checkpoint to load (.pt file for the specific task) -**\-\-wandb_entity** : - *Help*: Wandb entity + - Default: None + + - Choices: +**\-\-ckpt_name** : + *Help*: (optional) checkpoint save name. - Default: None - Choices: +**\-\-start_from** : + *Help*: Task to start from -**\-\-wandb_project** : - *Help*: Wandb project name + - Default: None - - Default: mammoth + - Choices: +**\-\-stop_after** : + *Help*: Task limit + + - Default: None - Choices: -**\-\-eval_epochs** : - *Help*: Perform inference intra-task at every `eval_epochs`. +.. rubric:: Wandb arguments + +*Arguments to manage logging on Wandb.* + +**\-\-wandb_name** : + *Help*: Wandb name for this run. Overrides the default name (`args.model`). - Default: None - Choices: +**\-\-wandb_entity** : + *Help*: Wandb entity -**\-\-inference_only** : None - *Help*: Perform inference only for each task (no training). + - Default: None - - Default: False + - Choices: +**\-\-wandb_project** : + *Help*: Wandb project name + + - Default: mammoth - Choices: diff --git a/docs/utils/index.rst b/docs/utils/index.rst index 096fbfb6..cafeab82 100644 --- a/docs/utils/index.rst +++ b/docs/utils/index.rst @@ -38,10 +38,10 @@ Other arguments such as the size of the training batch and the number of epochs python utils/main.py --dataset seq-cifar10 --model der --buffer_size 500 --lr 0.03 --batch_size 128 --epochs 10 .. note:: - To ease hyper-parameter tuning, all boolean arguments follow the convention: ``--=1`` for ``True`` and ``--=0`` for ``False``. The only exceptions are ``--savecheck`` and ``--inference_only``, as they should not be included in the hyper-parameter search. + To ease hyper-parameter tuning, all boolean arguments follow the convention: ``--=1`` for ``True`` and ``--=0`` for ``False``. Other useful arguments -~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~ * ``--debug_mode``: If set to ``1``, the model will run for only a few iterations per each epoch and will disable WandB logging. This is useful for debugging. diff --git a/models/__init__.py b/models/__init__.py index 0cdb90dc..a100ba9d 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -88,7 +88,7 @@ def _get_names(): names[c.NAME.replace('_', '-')] = c except Exception as e: warn_once("Error in model", model) - warn_once(e) + warn_once("\t-", e) names[model.replace('_', '-')] = e return names diff --git a/models/bic.py b/models/bic.py index 5ab64f38..2a9ff891 100644 --- a/models/bic.py +++ b/models/bic.py @@ -44,7 +44,6 @@ def __init__(self, backbone, loss, args, transform): super().__init__(backbone, loss, args, transform) dd = get_dataset(args) - self.transform = transform self.buffer = Buffer(self.args.buffer_size) self.lamda = 0 diff --git a/models/ccic.py b/models/ccic.py index 35ac7b61..4cdae0be 100644 --- a/models/ccic.py +++ b/models/ccic.py @@ -130,7 +130,7 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None): real_mask = mask[:real_batch_size] if (~real_mask).sum() > 0: - unsup_aug_inputs = self.weak_transform(not_aug_inputs[~real_mask].repeat_interleave(self.args.k_aug, 0)) + unsup_aug_inputs = self.transform(not_aug_inputs[~real_mask].repeat_interleave(self.args.k_aug, 0)) else: unsup_aug_inputs = torch.zeros((0,)).to(self.device) diff --git a/models/coda_prompt_utils/model.py b/models/coda_prompt_utils/model.py index f2d3696e..676c1fa7 100644 --- a/models/coda_prompt_utils/model.py +++ b/models/coda_prompt_utils/model.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from timm.models.vision_transformer import vit_base_patch16_224 +from backbone.vit import create_vision_transformer from models.coda_prompt_utils.vit import VisionTransformer import copy @@ -216,14 +216,16 @@ def __init__(self, num_classes=10, pt=False, prompt_param=None): # get feature encoder vit_model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, - num_heads=12, ckpt_layer=0, - drop_path_rate=0) + num_heads=12, drop_path_rate=0) if pt: - load_dict = vit_base_patch16_224(pretrained=True).state_dict() - del load_dict['head.weight'] - del load_dict['head.bias'] - vit_model.load_state_dict(load_dict) + load_dict = create_vision_transformer('vit_base_patch16_224', base_class=VisionTransformer, pretrained=True, num_classes=0).state_dict() + if 'head.weight' in load_dict: + del load_dict['head.weight'] + del load_dict['head.bias'] + missing, unexpected = vit_model.load_state_dict(load_dict, strict=False) + assert len([m for m in missing if 'head' not in m]) == 0, f"Missing keys: {missing}" + assert len(unexpected) == 0, f"Unexpected keys: {unexpected}" # classifier self.last = nn.Linear(768, num_classes) diff --git a/models/coda_prompt_utils/vit.py b/models/coda_prompt_utils/vit.py index c2fa3bbb..412c93cd 100644 --- a/models/coda_prompt_utils/vit.py +++ b/models/coda_prompt_utils/vit.py @@ -5,32 +5,12 @@ import torch import torch.nn as nn +import torch.nn.functional as F from functools import partial -from timm.models.vision_transformer import PatchEmbed from timm.models.layers import trunc_normal_, DropPath - -class Mlp(nn.Module): - """ MLP as used in Vision Transformer, MLP-Mixer and related networks - """ - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x +from backbone.vit import Mlp, VisionTransformer as MammothVP class Attention(nn.Module): @@ -59,7 +39,7 @@ def save_attention_map(self, attention_map): def get_attention_map(self): return self.attention_map - def forward(self, x, register_hook=False, prompt=None): + def forward(self, x, prompt=None): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) @@ -71,15 +51,13 @@ def forward(self, x, register_hook=False, prompt=None): k = torch.cat((pk, k), dim=2) v = torch.cat((pv, v), dim=2) - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p) + # attn = (q @ k.transpose(-2, -1)) * self.scale + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = (attn @ v) - if register_hook: - self.save_attention_map(attn) - attn.register_hook(self.save_attn_gradients) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x @@ -99,61 +77,33 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - def forward(self, x, register_hook=False, prompt=None): - x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook, prompt=prompt)) + def forward(self, x, prompt=None): + x = x + self.drop_path(self.attn(self.norm1(x), prompt=prompt)) x = x + self.drop_path(self.mlp(self.norm2(x))) return x -class VisionTransformer(nn.Module): - """ Vision Transformer - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 - """ - - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, - ckpt_layer=0): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - qk_scale (float): override default qk scale of head_dim ** -0.5 if set - representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set - drop_rate (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - norm_layer: (nn.Module): normalization layer - """ - super().__init__() - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) +class VisionTransformer(MammothVP): + def __init__(self, qk_scale=None, args=None, **kwargs): - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + super().__init__(args=args, **kwargs) num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + drop=self.pos_drop.p, + attn_drop=self.attn_drop_rate, + drop_path=self.dpr[i], + norm_layer=self.norm_layer, + act_layer=self.act_layer ) - for i in range(depth)]) - self.norm = norm_layer(embed_dim) + for i in range(self.depth)]) trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) @@ -168,11 +118,7 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed', 'cls_token'} - - def forward(self, x, register_blk=-1, prompt=None, q=None, train=False, task_id=None): + def forward(self, x, prompt=None, q=None, train=False, task_id=None): B = x.shape[0] x = self.patch_embed(x) @@ -194,7 +140,7 @@ def forward(self, x, register_blk=-1, prompt=None, q=None, train=False, task_id= else: p_list = None - x = blk(x, register_blk == i, prompt=p_list) + x = blk(x, prompt=p_list) x = self.norm(x) diff --git a/models/dualprompt_utils/attention.py b/models/dualprompt_utils/attention.py index b5291ff7..6f38b8de 100644 --- a/models/dualprompt_utils/attention.py +++ b/models/dualprompt_utils/attention.py @@ -24,7 +24,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) - def forward(self, x, prompt): + def forward(self, x, prompt=None): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) diff --git a/models/dualprompt_utils/vision_transformer.py b/models/dualprompt_utils/vision_transformer.py index ea73bcf8..9dd510b9 100644 --- a/models/dualprompt_utils/vision_transformer.py +++ b/models/dualprompt_utils/vision_transformer.py @@ -1,236 +1,46 @@ -""" Vision Transformer (ViT) in PyTorch - -A clone of ViT from timm's implementation, with dualprompt implementation. - -Copyright 2020, Ross Wightman -# ------------------------------------------ -# Modification: -# Added code for dualprompt implementation -# -- Jaeho Lee, dlwogh9344@khu.ac.kr -# ------------------------------------------ -""" import math import logging -from functools import partial -from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq -from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from timm.models.helpers import named_apply +from timm.models.layers import trunc_normal_ +from backbone.vit import Attention, create_vision_transformer, VisionTransformer as MammothVP, get_init_weights_vit from models.dualprompt_utils.prompt import EPrompt from models.dualprompt_utils.attention import PreT_Attention _logger = logging.getLogger(__name__) -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x, *args): - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - -class Block(nn.Module): +class VisionTransformer(MammothVP): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_layer=Attention): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def forward(self, x, prompt=None): - x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), prompt))) - x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - return x - - -class ResPostBlock(nn.Module): - - def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.init_values = init_values - - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) - self.norm1 = norm_layer(dim) - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) - self.norm2 = norm_layer(dim) - self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.init_weights() - - def init_weights(self): - # NOTE this init overrides that base model init with specific changes for the block type - if self.init_values is not None: - nn.init.constant_(self.norm1.weight, self.init_values) - nn.init.constant_(self.norm2.weight, self.init_values) - - def forward(self, x): - x = x + self.drop_path1(self.norm1(self.attn(x))) - x = x + self.drop_path2(self.norm2(self.mlp(x))) - return x - - -class ParallelBlock(nn.Module): + self, prompt_length=None, embedding_key='cls', prompt_init='uniform', prompt_pool=False, prompt_key=False, pool_size=None, + top_k=None, batchwise_prompt=False, prompt_key_init='uniform', head_type='token', use_prompt_mask=False, + use_g_prompt=False, g_prompt_length=None, g_prompt_layer_idx=None, use_prefix_tune_for_g_prompt=False, + use_e_prompt=False, e_prompt_layer_idx=None, use_prefix_tune_for_e_prompt=False, same_key_value=False, args=None, **kwargs): - def __init__( - self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.num_parallel = num_parallel - self.attns = nn.ModuleList() - self.ffns = nn.ModuleList() - for _ in range(num_parallel): - self.attns.append(nn.Sequential(OrderedDict([ - ('norm', norm_layer(dim)), - ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), - ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), - ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) - ]))) - self.ffns.append(nn.Sequential(OrderedDict([ - ('norm', norm_layer(dim)), - ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), - ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), - ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) - ]))) - - def _forward_jit(self, x): - x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) - x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) - return x - - @torch.jit.ignore - def _forward(self, x): - x = x + sum(attn(x) for attn in self.attns) - x = x + sum(ffn(x) for ffn in self.ffns) - return x - - def forward(self, x): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return self._forward_jit(x) + if not (use_g_prompt or use_e_prompt): + attn_layer = Attention + elif not (use_prefix_tune_for_g_prompt or use_prefix_tune_for_e_prompt): + # Prompt tunning + attn_layer = Attention else: - return self._forward(x) - - -class VisionTransformer(nn.Module): - """ Vision Transformer - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 - """ + # Prefix tunning + attn_layer = PreT_Attention - def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, - prompt_length=None, embedding_key='cls', prompt_init='uniform', prompt_pool=False, prompt_key=False, pool_size=None, - top_k=None, batchwise_prompt=False, prompt_key_init='uniform', head_type='token', use_prompt_mask=False, - use_g_prompt=False, g_prompt_length=None, g_prompt_layer_idx=None, use_prefix_tune_for_g_prompt=False, - use_e_prompt=False, e_prompt_layer_idx=None, use_prefix_tune_for_e_prompt=False, same_key_value=False,): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - global_pool (str): type of global pooling for final sequence (default: 'token') - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - init_values: (float): layer-scale init values - class_token (bool): use class token - fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) - drop_rate (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - weight_init (str): weight init scheme - embed_layer (nn.Module): patch embedding layer - norm_layer: (nn.Module): normalization layer - act_layer: (nn.Module): MLP activation layer - block_fn: (nn.Module): transformer block - prompt_pool (bool): use prompt pool or not - """ - super().__init__() - assert global_pool in ('', 'avg', 'token') - assert class_token or global_pool != 'token' - use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - - self.img_size = img_size - self.num_classes = num_classes - self.global_pool = global_pool - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.class_token = class_token - self.num_prefix_tokens = 1 if class_token else 0 - self.no_embed_class = no_embed_class - self.grad_checkpointing = False - - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None - embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens - self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) - self.pos_drop = nn.Dropout(p=drop_rate) + super().__init__(args=args, attn_layer=attn_layer, **kwargs) self.prompt_pool = prompt_pool self.head_type = head_type self.use_prompt_mask = use_prompt_mask - self.use_g_prompt = use_g_prompt self.g_prompt_layer_idx = g_prompt_layer_idx + # num_g_prompt : The actual number of layers to which g-prompt is attached. # In official code, create as many layers as the total number of layers and select them based on the index num_g_prompt = len(self.g_prompt_layer_idx) if self.g_prompt_layer_idx is not None else 0 @@ -247,7 +57,7 @@ def __init__( if use_g_prompt and g_prompt_length is not None and len(g_prompt_layer_idx) != 0: if not use_prefix_tune_for_g_prompt: - g_prompt_shape = (num_g_prompt, g_prompt_length, embed_dim) + g_prompt_shape = (num_g_prompt, g_prompt_length, self.embed_dim) if prompt_init == 'zero': self.g_prompt = nn.Parameter(torch.zeros(g_prompt_shape)) elif prompt_init == 'uniform': @@ -255,7 +65,7 @@ def __init__( nn.init.uniform_(self.g_prompt, -1, 1) else: if same_key_value: - g_prompt_shape = (num_g_prompt, 1, g_prompt_length, num_heads, embed_dim // num_heads) + g_prompt_shape = (num_g_prompt, 1, g_prompt_length, self.num_heads, self.embed_dim // self.num_heads) if prompt_init == 'zero': self.g_prompt = nn.Parameter(torch.zeros(g_prompt_shape)) elif prompt_init == 'uniform': @@ -263,7 +73,7 @@ def __init__( nn.init.uniform_(self.g_prompt, -1, 1) self.g_prompt = self.g_prompt.repeat(1, 2, 1, 1, 1) else: - g_prompt_shape = (num_g_prompt, 2, g_prompt_length, num_heads, embed_dim // num_heads) + g_prompt_shape = (num_g_prompt, 2, g_prompt_length, self.num_heads, self.embed_dim // self.num_heads) if prompt_init == 'zero': self.g_prompt = nn.Parameter(torch.zeros(g_prompt_shape)) elif prompt_init == 'uniform': @@ -273,19 +83,10 @@ def __init__( self.g_prompt = None if use_e_prompt and e_prompt_layer_idx is not None: - self.e_prompt = EPrompt(length=prompt_length, embed_dim=embed_dim, embedding_key=embedding_key, prompt_init=prompt_init, + self.e_prompt = EPrompt(length=prompt_length, embed_dim=self.embed_dim, embedding_key=embedding_key, prompt_init=prompt_init, prompt_pool=prompt_pool, prompt_key=prompt_key, pool_size=pool_size, top_k=top_k, batchwise_prompt=batchwise_prompt, prompt_key_init=prompt_key_init, num_layers=num_e_prompt, use_prefix_tune_for_e_prompt=use_prefix_tune_for_e_prompt, - num_heads=num_heads, same_key_value=same_key_value) - - if not (use_g_prompt or use_e_prompt): - attn_layer = Attention - elif not (use_prefix_tune_for_g_prompt or use_prefix_tune_for_e_prompt): - # Prompt tunning - attn_layer = Attention - else: - # Prefix tunning - attn_layer = PreT_Attention + num_heads=self.num_heads, same_key_value=same_key_value) self.total_prompt_len = 0 if self.prompt_pool: @@ -294,20 +95,10 @@ def __init__( if not self.use_prefix_tune_for_e_prompt: self.total_prompt_len += prompt_length * top_k * len(self.e_prompt_layer_idx) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.blocks = nn.Sequential(*[ - block_fn( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, attn_layer=attn_layer) - for i in range(depth)]) - self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() - - # Classifier Head - self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() - if weight_init != 'skip': - self.init_weights(weight_init) + if self.weight_init != 'skip': + self.init_weights(self.weight_init) def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') @@ -317,40 +108,6 @@ def init_weights(self, mode=''): nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) - def _init_weights(self, m): - # this fn left here for compat with downstream users - init_weights_vit_timm(m) - - @torch.jit.ignore() - def load_pretrained(self, checkpoint_path, prefix=''): - _load_weights(self, checkpoint_path, prefix) - - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed', 'cls_token', 'dist_token'} - - @torch.jit.ignore - def group_matcher(self, coarse=False): - return dict( - stem=r'^cls_token|pos_embed|patch_embed', # stem and embed - blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] - ) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.grad_checkpointing = enable - - @torch.jit.ignore - def get_classifier(self): - return self.head - - def reset_classifier(self, num_classes: int, global_pool=None): - self.num_classes = num_classes - if global_pool is not None: - assert global_pool in ('', 'avg', 'token') - self.global_pool = global_pool - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - def forward_features(self, x, task_id=-1, cls_features=None, train=False): x = self.patch_embed(x) @@ -359,53 +116,50 @@ def forward_features(self, x, task_id=-1, cls_features=None, train=False): x = self.pos_drop(x + self.pos_embed) - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x) - else: - if self.use_g_prompt or self.use_e_prompt: - if self.use_prompt_mask and train: - start = task_id * self.e_prompt.top_k - end = (task_id + 1) * self.e_prompt.top_k - single_prompt_mask = torch.arange(start, end).to(x.device) - prompt_mask = single_prompt_mask.unsqueeze(0).expand(x.shape[0], -1) - if end > self.e_prompt.pool_size: - prompt_mask = None - else: + if self.use_g_prompt or self.use_e_prompt: + if self.use_prompt_mask and train: + start = task_id * self.e_prompt.top_k + end = (task_id + 1) * self.e_prompt.top_k + single_prompt_mask = torch.arange(start, end).to(x.device) + prompt_mask = single_prompt_mask.unsqueeze(0).expand(x.shape[0], -1) + if end > self.e_prompt.pool_size: prompt_mask = None + else: + prompt_mask = None + + g_prompt_counter = -1 + e_prompt_counter = -1 - g_prompt_counter = -1 - e_prompt_counter = -1 - - res = self.e_prompt(x, prompt_mask=prompt_mask, cls_features=cls_features) - e_prompt = res['batched_prompt'] - - for i, block in enumerate(self.blocks): - if i in self.g_prompt_layer_idx: - if self.use_prefix_tune_for_g_prompt: - g_prompt_counter += 1 - # Prefix tunning, [B, 2, g_prompt_length, num_heads, embed_dim // num_heads] - idx = torch.tensor([g_prompt_counter] * x.shape[0]).to(x.device) - g_prompt = self.g_prompt[idx] - else: - g_prompt = None - x = block(x, prompt=g_prompt) - - elif i in self.e_prompt_layer_idx: - e_prompt_counter += 1 - if self.use_prefix_tune_for_e_prompt: - # Prefix tunning, [B, 2, top_k * e_prompt_length, num_heads, embed_dim // num_heads] - x = block(x, prompt=e_prompt[e_prompt_counter]) - else: - # Pommpt tunning, [B, top_k * e_prompt_length, embed_dim] - prompt = e_prompt[e_prompt_counter] - x = torch.cat([prompt, x], dim=1) - x = block(x) + res = self.e_prompt(x, prompt_mask=prompt_mask, cls_features=cls_features) + e_prompt = res['batched_prompt'] + + for i, block in enumerate(self.blocks): + if i in self.g_prompt_layer_idx: + if self.use_prefix_tune_for_g_prompt: + g_prompt_counter += 1 + # Prefix tunning, [B, 2, g_prompt_length, num_heads, embed_dim // num_heads] + idx = torch.tensor([g_prompt_counter] * x.shape[0]).to(x.device) + g_prompt = self.g_prompt[idx] + else: + g_prompt = None + x = block(x, prompt=g_prompt) + + elif i in self.e_prompt_layer_idx: + e_prompt_counter += 1 + if self.use_prefix_tune_for_e_prompt: + # Prefix tunning, [B, 2, top_k * e_prompt_length, num_heads, embed_dim // num_heads] + x = block(x, prompt=e_prompt[e_prompt_counter]) else: + # Pommpt tunning, [B, top_k * e_prompt_length, embed_dim] + prompt = e_prompt[e_prompt_counter] + x = torch.cat([prompt, x], dim=1) x = block(x) - else: - x = self.blocks(x) + else: + x = block(x) + else: + x = self.blocks(x) - res = dict() + res = dict() x = self.norm(x) res['x'] = x @@ -444,143 +198,6 @@ def forward(self, x, task_id=-1, cls_features=None, train=False): return res -def init_weights_vit_timm(module: nn.Module, name: str = ''): - """ ViT weight initialization, original timm impl (for reproducibility) """ - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif hasattr(module, 'init_weights'): - module.init_weights() - - -def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): - """ ViT weight initialization, matching JAX (Flax) impl """ - if isinstance(module, nn.Linear): - if name.startswith('head'): - nn.init.zeros_(module.weight) - nn.init.constant_(module.bias, head_bias) - else: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) - elif isinstance(module, nn.Conv2d): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif hasattr(module, 'init_weights'): - module.init_weights() - - -def init_weights_vit_moco(module: nn.Module, name: str = ''): - """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ - if isinstance(module, nn.Linear): - if 'qkv' in name: - # treat the weights of Q, K, V separately - val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) - nn.init.uniform_(module.weight, -val, val) - else: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif hasattr(module, 'init_weights'): - module.init_weights() - - -def get_init_weights_vit(mode='jax', head_bias: float = 0.): - if 'jax' in mode: - return partial(init_weights_vit_jax, head_bias=head_bias) - elif 'moco' in mode: - return init_weights_vit_moco - else: - return init_weights_vit_timm - - -@torch.no_grad() -def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): - """ Load weights from .npz checkpoints for official Google Brain Flax implementation - """ - import numpy as np - - def _n2p(w, t=True): - if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: - w = w.flatten() - if t: - if w.ndim == 4: - w = w.transpose([3, 2, 0, 1]) - elif w.ndim == 3: - w = w.transpose([2, 0, 1]) - elif w.ndim == 2: - w = w.transpose([1, 0]) - return torch.from_numpy(w) - - w = np.load(checkpoint_path) - if not prefix and 'opt/target/embedding/kernel' in w: - prefix = 'opt/target/' - - if hasattr(model.patch_embed, 'backbone'): - # hybrid - backbone = model.patch_embed.backbone - stem_only = not hasattr(backbone, 'stem') - stem = backbone if stem_only else backbone.stem - stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) - stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) - stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) - if not stem_only: - for i, stage in enumerate(backbone.stages): - for j, block in enumerate(stage.blocks): - bp = f'{prefix}block{i + 1}/unit{j + 1}/' - for r in range(3): - getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) - getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) - getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) - if block.downsample is not None: - block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) - block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) - block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) - embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) - else: - embed_conv_w = adapt_input_conv( - model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) - model.patch_embed.proj.weight.copy_(embed_conv_w) - model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) - model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) - pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) - if pos_embed_w.shape != model.pos_embed.shape: - pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, - model.pos_embed, - getattr(model, 'num_prefix_tokens', 1), - model.patch_embed.grid_size - ) - model.pos_embed.copy_(pos_embed_w) - model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) - model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) - if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: - model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) - model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) - # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights - # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: - # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) - # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) - for i, block in enumerate(model.blocks.children()): - block_prefix = f'{prefix}Transformer/encoderblock_{i}/' - mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' - block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) - block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) - block.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) - block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - for r in range(2): - getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) - getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) - - def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 @@ -639,40 +256,10 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): return out_dict -def _create_vision_transformer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - if 'flexi' in variant: - # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed - # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. - _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) - else: - _filter_fn = checkpoint_filter_fn - - # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln? - strict = True - if 'siglip' in variant and kwargs.get('global_pool', None) != 'map': - strict = False - - pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) - pretrained_cfg.custom_load = True - - return build_model_with_cfg( - VisionTransformer, - variant, - pretrained, - pretrained_cfg=pretrained_cfg, - pretrained_filter_fn=_filter_fn, - pretrained_strict=strict, - **kwargs, - ) - - def vit_base_patch16_224_dualprompt(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + model = create_vision_transformer('vit_base_patch16_224', base_class=VisionTransformer, filter_fn=checkpoint_filter_fn, pretrained=pretrained, **model_kwargs) return model diff --git a/models/er_ace.py b/models/er_ace.py index 28673296..4f115efa 100644 --- a/models/er_ace.py +++ b/models/er_ace.py @@ -51,7 +51,7 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None): self.args.minibatch_size, transform=self.transform, device=self.device) loss_re = self.loss(self.net(buf_inputs), buf_labels) - loss += loss_re + loss += loss_re loss.backward() self.opt.step() diff --git a/models/gdumb_lider.py b/models/gdumb_lider.py index 9f271847..74a13ab2 100644 --- a/models/gdumb_lider.py +++ b/models/gdumb_lider.py @@ -7,7 +7,6 @@ from utils.status import progress_bar - def fit_buffer(self: LiderOptimizer, epochs): optimizer = SGD(self.get_parameters(), lr=self.args.maxlr, momentum=self.args.optim_mom, weight_decay=self.args.optim_wd, nesterov=self.args.optim_nesterov) scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2, eta_min=self.args.minlr) diff --git a/models/l2p_utils/vit_prompt.py b/models/l2p_utils/vit_prompt.py index 19328def..3d74ed15 100644 --- a/models/l2p_utils/vit_prompt.py +++ b/models/l2p_utils/vit_prompt.py @@ -35,221 +35,28 @@ import torch.utils.checkpoint from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq +from timm.models.helpers import named_apply, adapt_input_conv, checkpoint_seq from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from backbone.vit import create_vision_transformer, VisionTransformer as MammothVP from models.l2p_utils.prompt import Prompt _logger = logging.getLogger(__name__) -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x): - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - -class Block(nn.Module): - - def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def forward(self, x): - x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) - x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - return x - - -class ResPostBlock(nn.Module): - +class VisionTransformer(MammothVP): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.init_values = init_values - - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) - self.norm1 = norm_layer(dim) - self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) - self.norm2 = norm_layer(dim) - self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self, prompt_length=None, embedding_key='cls', prompt_init='uniform', prompt_pool=False, prompt_key=False, pool_size=None, + top_k=None, batchwise_prompt=False, prompt_key_init='uniform', head_type='token', use_prompt_mask=False, prompt_shuffle=False, args=None, **kwargs): - self.init_weights() + super().__init__(args=args, **kwargs) - def init_weights(self): - # NOTE this init overrides that base model init with specific changes for the block type - if self.init_values is not None: - nn.init.constant_(self.norm1.weight, self.init_values) - nn.init.constant_(self.norm2.weight, self.init_values) - - def forward(self, x): - x = x + self.drop_path1(self.norm1(self.attn(x))) - x = x + self.drop_path2(self.norm2(self.mlp(x))) - return x - - -class ParallelBlock(nn.Module): - - def __init__( - self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.num_parallel = num_parallel - self.attns = nn.ModuleList() - self.ffns = nn.ModuleList() - for _ in range(num_parallel): - self.attns.append(nn.Sequential(OrderedDict([ - ('norm', norm_layer(dim)), - ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), - ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), - ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) - ]))) - self.ffns.append(nn.Sequential(OrderedDict([ - ('norm', norm_layer(dim)), - ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), - ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), - ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) - ]))) - - def _forward_jit(self, x): - x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) - x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) - return x - - @torch.jit.ignore - def _forward(self, x): - x = x + sum(attn(x) for attn in self.attns) - x = x + sum(ffn(x) for ffn in self.ffns) - return x - - def forward(self, x): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return self._forward_jit(x) - else: - return self._forward(x) + self.num_prefix_tokens = 1 if self.class_token else 0 - -class VisionTransformer(nn.Module): - """ Vision Transformer - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 - """ - - def __init__( - self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, - prompt_length=None, embedding_key='cls', prompt_init='uniform', prompt_pool=False, prompt_key=False, pool_size=None, - top_k=None, batchwise_prompt=False, prompt_key_init='uniform', head_type='token', use_prompt_mask=False, prompt_shuffle=False): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - global_pool (str): type of global pooling for final sequence (default: 'token') - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - init_values: (float): layer-scale init values - class_token (bool): use class token - fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) - drop_rate (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - weight_init (str): weight init scheme - embed_layer (nn.Module): patch embedding layer - norm_layer: (nn.Module): normalization layer - act_layer: (nn.Module): MLP activation layer - block_fn: (nn.Module): transformer block - prompt_pool (bool): use prompt pool or not - """ - super().__init__() - assert global_pool in ('', 'avg', 'token') - assert class_token or global_pool != 'token' - use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - - self.img_size = img_size - self.num_classes = num_classes - self.global_pool = global_pool - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.class_token = class_token - self.num_prefix_tokens = 1 if class_token else 0 - self.no_embed_class = no_embed_class - self.grad_checkpointing = False - - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None - embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + embed_len = self.pos_embed.shape[1] if prompt_length is not None and pool_size is not None and prompt_pool: embed_len += prompt_length * top_k - self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) - self.pos_drop = nn.Dropout(p=drop_rate) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, self.embed_dim) * .02) self.prompt_pool = prompt_pool self.head_type = head_type @@ -257,66 +64,15 @@ def __init__( self.prompt_shuffle = prompt_shuffle if prompt_length is not None and pool_size is not None and prompt_pool: - self.prompt = Prompt(length=prompt_length, embed_dim=embed_dim, embedding_key=embedding_key, prompt_init=prompt_init, + self.prompt = Prompt(length=prompt_length, embed_dim=self.embed_dim, embedding_key=embedding_key, prompt_init=prompt_init, prompt_pool=prompt_pool, prompt_key=prompt_key, pool_size=pool_size, top_k=top_k, batchwise_prompt=batchwise_prompt, prompt_key_init=prompt_key_init, prompt_shuffle=self.prompt_shuffle) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.blocks = nn.Sequential(*[ - block_fn( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) - for i in range(depth)]) - self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() - # Classifier Head - self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - - if weight_init != 'skip': - self.init_weights(weight_init) + self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() - def init_weights(self, mode=''): - assert mode in ('jax', 'jax_nlhb', 'moco', '') - head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. - trunc_normal_(self.pos_embed, std=.02) - if self.cls_token is not None: - nn.init.normal_(self.cls_token, std=1e-6) - named_apply(get_init_weights_vit(mode, head_bias), self) - - def _init_weights(self, m): - # this fn left here for compat with downstream users - init_weights_vit_timm(m) - - @torch.jit.ignore() - def load_pretrained(self, checkpoint_path, prefix=''): - _load_weights(self, checkpoint_path, prefix) - - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed', 'cls_token', 'dist_token'} - - @torch.jit.ignore - def group_matcher(self, coarse=False): - return dict( - stem=r'^cls_token|pos_embed|patch_embed', # stem and embed - blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] - ) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.grad_checkpointing = enable - - @torch.jit.ignore - def get_classifier(self): - return self.head - - def reset_classifier(self, num_classes: int, global_pool=None): - self.num_classes = num_classes - if global_pool is not None: - assert global_pool in ('', 'avg', 'token') - self.global_pool = global_pool - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.weight_init != 'skip': + self.init_weights(self.weight_init) def forward_features(self, x, task_id=-1, cls_features=None, train=False): x = self.patch_embed(x) @@ -341,10 +97,7 @@ def forward_features(self, x, task_id=-1, cls_features=None, train=False): x = self.pos_drop(x + self.pos_embed) - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x) - else: - x = self.blocks(x) + x = self.blocks(x) x = self.norm(x) res['x'] = x @@ -374,147 +127,17 @@ def forward_head(self, res, pre_logits: bool = False): return res - def forward(self, x, task_id=-1, cls_features=None, train=False): - res = self.forward_features(x, task_id=task_id, cls_features=cls_features, train=train) - res = self.forward_head(res) - return res - + def forward(self, x, task_id=-1, cls_features=None, train=False, returnt='out'): + assert returnt in ('out', 'features', 'both') -def init_weights_vit_timm(module: nn.Module, name: str = ''): - """ ViT weight initialization, original timm impl (for reproducibility) """ - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif hasattr(module, 'init_weights'): - module.init_weights() + feats = self.forward_features(x, task_id=task_id, cls_features=cls_features, train=train) + if returnt == 'features': + return feats - -def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): - """ ViT weight initialization, matching JAX (Flax) impl """ - if isinstance(module, nn.Linear): - if name.startswith('head'): - nn.init.zeros_(module.weight) - nn.init.constant_(module.bias, head_bias) - else: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) - elif isinstance(module, nn.Conv2d): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif hasattr(module, 'init_weights'): - module.init_weights() - - -def init_weights_vit_moco(module: nn.Module, name: str = ''): - """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ - if isinstance(module, nn.Linear): - if 'qkv' in name: - # treat the weights of Q, K, V separately - val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) - nn.init.uniform_(module.weight, -val, val) - else: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif hasattr(module, 'init_weights'): - module.init_weights() - - -def get_init_weights_vit(mode='jax', head_bias: float = 0.): - if 'jax' in mode: - return partial(init_weights_vit_jax, head_bias=head_bias) - elif 'moco' in mode: - return init_weights_vit_moco - else: - return init_weights_vit_timm - - -@torch.no_grad() -def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): - """ Load weights from .npz checkpoints for official Google Brain Flax implementation - """ - import numpy as np - - def _n2p(w, t=True): - if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: - w = w.flatten() - if t: - if w.ndim == 4: - w = w.transpose([3, 2, 0, 1]) - elif w.ndim == 3: - w = w.transpose([2, 0, 1]) - elif w.ndim == 2: - w = w.transpose([1, 0]) - return torch.from_numpy(w) - - w = np.load(checkpoint_path) - if not prefix and 'opt/target/embedding/kernel' in w: - prefix = 'opt/target/' - - if hasattr(model.patch_embed, 'backbone'): - # hybrid - backbone = model.patch_embed.backbone - stem_only = not hasattr(backbone, 'stem') - stem = backbone if stem_only else backbone.stem - stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) - stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) - stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) - if not stem_only: - for i, stage in enumerate(backbone.stages): - for j, block in enumerate(stage.blocks): - bp = f'{prefix}block{i + 1}/unit{j + 1}/' - for r in range(3): - getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) - getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) - getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) - if block.downsample is not None: - block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) - block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) - block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) - embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) - else: - embed_conv_w = adapt_input_conv( - model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) - model.patch_embed.proj.weight.copy_(embed_conv_w) - model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) - model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) - pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) - if pos_embed_w.shape != model.pos_embed.shape: - pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, - model.pos_embed, - getattr(model, 'num_prefix_tokens', 1), - model.patch_embed.grid_size - ) - model.pos_embed.copy_(pos_embed_w) - model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) - model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) - if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: - model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) - model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) - # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights - # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: - # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) - # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) - for i, block in enumerate(model.blocks.children()): - block_prefix = f'{prefix}Transformer/encoderblock_{i}/' - mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' - block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) - block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) - block.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) - block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - for r in range(2): - getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) - getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + res = self.forward_head(feats) + if returnt == 'both': + return res, feats + return res def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): @@ -575,34 +198,10 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): return out_dict -def _create_vision_transformer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - if 'flexi' in variant: - # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed - # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. - _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) - else: - _filter_fn = checkpoint_filter_fn - - pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) - pretrained_cfg.custom_load = True - - return build_model_with_cfg( - VisionTransformer, - variant, - pretrained, - pretrained_cfg=pretrained_cfg, - pretrained_filter_fn=_filter_fn, - **kwargs, - ) - - def vit_base_patch16_224_l2p(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + model = create_vision_transformer('vit_base_patch16_224', base_class=VisionTransformer, filter_fn=checkpoint_filter_fn, pretrained=pretrained, **model_kwargs) return model diff --git a/models/pnn.py b/models/pnn.py index 0168e15d..668d3c20 100644 --- a/models/pnn.py +++ b/models/pnn.py @@ -17,7 +17,7 @@ def get_backbone(bone, old_cols=None, x_shape=None): from backbone.MNISTMLP import MNISTMLP from backbone.MNISTMLP_PNN import MNISTMLP_PNN - from backbone.ResNet18 import ResNet + from backbone.ResNetBlock import ResNet from backbone.ResNet18_PNN import resnet18_pnn if isinstance(bone, MNISTMLP): diff --git a/models/slca.py b/models/slca.py index 177e116a..9647e187 100644 --- a/models/slca.py +++ b/models/slca.py @@ -51,7 +51,7 @@ def __init__(self, backbone, loss, args, transform): print("-" * 20) args.milestones = args.milestones.split(',') - n_features = backbone._network.feature_dim + n_features = backbone._network.convnet.feature_dim super().__init__(backbone, loss, args, transform) self.class_means = torch.zeros(self.num_classes, n_features).to(self.device) self.class_covs = torch.zeros(self.num_classes, n_features, n_features).to(self.device) diff --git a/models/slca_utils/convs/resnet.py b/models/slca_utils/convs/resnet.py deleted file mode 100644 index 4abd04f6..00000000 --- a/models/slca_utils/convs/resnet.py +++ /dev/null @@ -1,362 +0,0 @@ -''' -Reference: -https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py -''' -import torch -import torch.nn as nn -# from torchvision.models.utils import load_state_dict_from_url - - -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'wide_resnet50_2', 'wide_resnet101_2'] - - -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', - 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', - 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', - 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', - 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', -} - - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - __constants__ = ['downsample'] - - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None, no_last_relu=False): - super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - self.no_last_relu = no_last_relu - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - if not self.no_last_relu: - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - __constants__ = ['downsample'] - - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None, no_last_relu=False): - super(Bottleneck, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - self.no_last_relu = no_last_relu - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - if not self.no_last_relu: - out = self.relu(out) - - return out - - -# 修改Resnet的实现。 -class ResNet(nn.Module): - - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None, cifar=False, no_last_relu=False): - super(ResNet, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - self._norm_layer = norm_layer - self.cifar = cifar - - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - # each element in the tuple indicates if we should replace - # the 2x2 stride with a dilated convolution instead - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) - self.groups = groups - self.base_width = width_per_group - if self.cifar: - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) - else: - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Removed in _forward_impl for cifar - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2], no_last_relu=no_last_relu) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.out_dim = 512 * block.expansion - # self.fc = nn.Linear(512 * block.expansion, num_classes) # Removed in _forward_impl - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. - # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) - - def _make_layer(self, block, planes, blocks, stride=1, dilate=False, no_last_relu=False): - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - norm_layer(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) - self.inplanes = planes * block.expansion - for bid in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer, no_last_relu=no_last_relu if bid == blocks - 1 else False)) - - return nn.Sequential(*layers) - - def _forward_impl(self, x): - # See note [TorchScript super()] - x = self.conv1(x) # [bs, 64, 32, 32] - x = self.bn1(x) - x = self.relu(x) - if not self.cifar: - x = self.maxpool(x) - - x_1 = self.layer1(x) # [bs, 128, 32, 32] - x_2 = self.layer2(x_1) # [bs, 256, 16, 16] - x_3 = self.layer3(x_2) # [bs, 512, 8, 8] - x_4 = self.layer4(x_3) # [bs, 512, 4, 4] - - pooled = self.avgpool(x_4) # [bs, 512, 1, 1] - features = torch.flatten(pooled, 1) # [bs, 512] - # x = self.fc(x) - - return { - 'fmaps': [x_1, x_2, x_3, x_4], - 'features': features - } - - def forward(self, x): - return self._forward_impl(x) - - @property - def last_conv(self): - if hasattr(self.layer4[-1], 'conv3'): - return self.layer4[-1].conv3 - else: - return self.layer4[-1].conv2 - - -def _resnet(arch, block, layers, pretrained, progress, **kwargs): - model = ResNet(block, layers, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) - model.load_state_dict(state_dict) - return model - - -def resnet18(pretrained=False, progress=True, **kwargs): - r"""ResNet-18 model from - `"Deep Residual Learning for Image Recognition" `_ - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, - **kwargs) - - -def resnet34(pretrained=False, progress=True, **kwargs): - r"""ResNet-34 model from - `"Deep Residual Learning for Image Recognition" `_ - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - -def resnet50(pretrained=False, progress=True, **kwargs): - r"""ResNet-50 model from - `"Deep Residual Learning for Image Recognition" `_ - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, - **kwargs) - - -def resnet101(pretrained=False, progress=True, **kwargs): - r"""ResNet-101 model from - `"Deep Residual Learning for Image Recognition" `_ - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, - **kwargs) - - -def resnet152(pretrained=False, progress=True, **kwargs): - r"""ResNet-152 model from - `"Deep Residual Learning for Image Recognition" `_ - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, - **kwargs) - - -def resnext50_32x4d(pretrained=False, progress=True, **kwargs): - r"""ResNeXt-50 32x4d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_ - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 4 - return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) - - -def resnext101_32x8d(pretrained=False, progress=True, **kwargs): - r"""ResNeXt-101 32x8d model from - `"Aggregated Residual Transformation for Deep Neural Networks" `_ - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 8 - return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) - - -def wide_resnet50_2(pretrained=False, progress=True, **kwargs): - r"""Wide ResNet-50-2 model from - `"Wide Residual Networks" `_ - The model is the same as ResNet except for the bottleneck number of channels - which is twice larger in every block. The number of channels in outer 1x1 - convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 - channels, and in Wide ResNet-50-2 has 2048-1024-2048. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) - - -def wide_resnet101_2(pretrained=False, progress=True, **kwargs): - r"""Wide ResNet-101-2 model from - `"Wide Residual Networks" `_ - The model is the same as ResNet except for the bottleneck number of channels - which is twice larger in every block. The number of channels in outer 1x1 - convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 - channels, and in Wide ResNet-50-2 has 2048-1024-2048. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) diff --git a/models/slca_utils/convs/vits.py b/models/slca_utils/convs/vits.py deleted file mode 100644 index 3615ee94..00000000 --- a/models/slca_utils/convs/vits.py +++ /dev/null @@ -1,689 +0,0 @@ -""" Vision Transformer (ViT) in PyTorch - -A PyTorch implement of Vision Transformers as described in: - -'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - - https://arxiv.org/abs/2010.11929 - -`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - - https://arxiv.org/abs/2106.10270 - -The official jax code is released and available at https://github.com/google-research/vision_transformer - -DeiT model defs and weights from https://github.com/facebookresearch/deit, -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 - -Acknowledgments: -* The paper authors for releasing code and weights, thanks! -* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out -for some einops/einsum fun -* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT -* Bert reference code checks against Huggingface Transformers and Tensorflow Bert - -Hacked together by / Copyright 2020, Ross Wightman -""" -import math -import logging -from functools import partial -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv, resolve_pretrained_cfg -from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ - -_logger = logging.getLogger(__name__) - - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = { - # patch models (weights from official Google JAX impl) - 'vit_tiny_patch16_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), - 'vit_tiny_patch16_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_patch32_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), - 'vit_small_patch32_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_patch16_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), - 'vit_small_patch16_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch32_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), - 'vit_base_patch32_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch16_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), - 'vit_base_patch16_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch8_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), - 'vit_large_patch32_224': _cfg( - url='', # no official model weights for this combo, only for in21k - ), - 'vit_large_patch32_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_large_patch16_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), - 'vit_large_patch16_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - - 'vit_huge_patch14_224': _cfg(url=''), - 'vit_giant_patch14_224': _cfg(url=''), - 'vit_gigantic_patch14_224': _cfg(url=''), - - # patch models, imagenet21k (weights from official Google JAX impl) - 'vit_tiny_patch16_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_small_patch32_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_small_patch16_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_base_patch32_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_base_patch16_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', - # url='./B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_base_patch8_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_large_patch32_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', - num_classes=21843), - 'vit_large_patch16_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', - num_classes=21843), - 'vit_huge_patch14_224_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', - hf_hub='timm/vit_huge_patch14_224_in21k', - num_classes=21843), - - # SAM trained models (https://arxiv.org/abs/2106.01548) - 'vit_base_patch32_sam_224': _cfg( - url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), - 'vit_base_patch16_sam_224': _cfg( - url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), - - # deit models (FB weights) - 'deit_tiny_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'deit_small_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'deit_base_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'deit_base_patch16_384': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), - 'deit_tiny_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), - 'deit_small_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), - 'deit_base_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), - 'deit_base_distilled_patch16_384': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, - classifier=('head', 'head_dist')), - - # ViT ImageNet-21K-P pretraining by MILL - 'vit_base_patch16_224_miil_in21k': _cfg( - url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', - mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, - ), - 'vit_base_patch16_224_miil': _cfg( - url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' - '/vit_base_patch16_224_1k_miil_84_4.pth', - mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', - ), -} - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x): - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def forward(self, x): - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - -class VisionTransformer(nn.Module): - """ Vision Transformer - - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 - - Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` - - https://arxiv.org/abs/2012.12877 - """ - - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, - act_layer=None, weight_init='', with_adapter=False, global_pool=False): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set - distilled (bool): model includes a distillation token and head as in DeiT models - drop_rate (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - embed_layer (nn.Module): patch embedding layer - norm_layer: (nn.Module): normalization layer - weight_init: (str): weight init scheme - """ - super().__init__() - self.num_classes = num_classes - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.out_dim = embed_dim - self.num_tokens = 2 if distilled else 1 - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - - self.with_adapter = with_adapter - self.global_pool = global_pool - - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, - attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) - for i in range(depth)]) - self.norm = norm_layer(embed_dim) - - # Representation layer - if representation_size and not distilled: - self.num_features = representation_size - self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(embed_dim, representation_size)), - ('act', nn.Tanh()) - ])) - else: - self.pre_logits = nn.Identity() - - # Classifier head(s) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = None - if distilled: - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() - - if self.with_adapter: - self.adp_layers = [] - for adp_i in range(4): - self.adp_layers.append(self.get_adapter(embed_dim)) - self.adp_layers = nn.ModuleList(self.adp_layers) - self.adp_norm = nn.LayerNorm(embed_dim) - self.extra_blocks = nn.ModuleList([]) - self.init_weights(weight_init) - if self.with_adapter: - for adp_i in range(4): - nn.init.constant_(self.adp_layers[adp_i][-2].bias, -2.19) - - def get_adapter(self, embed_dim): - return nn.Sequential( - nn.Linear(embed_dim, embed_dim * 3, bias=False), - nn.LayerNorm(embed_dim * 3), - nn.GELU(), - nn.Linear(embed_dim * 3, embed_dim, bias=False), - nn.LayerNorm(embed_dim), - nn.GELU(), - nn.Linear(embed_dim, embed_dim, bias=True), - nn.Sigmoid() - ) - - def init_weights(self, mode=''): - assert mode in ('jax', 'jax_nlhb', 'nlhb', '') - head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. - trunc_normal_(self.pos_embed, std=.02) - if self.dist_token is not None: - trunc_normal_(self.dist_token, std=.02) - if mode.startswith('jax'): - # leave cls token as zeros to match jax impl - named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) - else: - trunc_normal_(self.cls_token, std=.02) - self.apply(_init_vit_weights) - - def _init_weights(self, m): - # this fn left here for compat with downstream users - _init_vit_weights(m) - - @torch.jit.ignore() - def load_pretrained(self, checkpoint_path, prefix=''): - _load_weights(self, checkpoint_path, prefix) - - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed', 'cls_token', 'dist_token'} - - def get_classifier(self): - if self.dist_token is None: - return self.head - else: - return self.head, self.head_dist - - def reset_classifier(self, num_classes, global_pool=''): - self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - if self.num_tokens == 2: - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() - - def forward_features(self, x, prompt=None, layer_feat=False): - img = x - x = self.patch_embed(x) - cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks - prompt_length = 0 - if self.dist_token is None and prompt is None: - x = torch.cat((cls_token, x), dim=1) - elif prompt is not None: - x = torch.cat((prompt, cls_token, x), dim=1) - prompt_length = prompt.size(1) - else: - x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) - x[:, prompt_length:] = self.pos_drop(x[:, prompt_length:] + self.pos_embed) - # x = self.blocks(x) - feats = [] - feats_l = [] - for b_id, block in enumerate(self.blocks): - x = block(x) - if self.with_adapter and (b_id + 1) % (len(self.blocks) // 4) == 0: - feats.append(x) - if layer_feat: - feats_l.append(x) - if b_id == len(self.blocks) - 2: - penultimate_feat = x.clone() - - if layer_feat: - return feats_l - - if len(self.extra_blocks) > 0: - assert not self.with_adapter - outs = [self.norm(x)[:, 0]] - for extra_block in self.extra_blocks: - outs.append(extra_block(penultimate_feat)[:, 0]) - return outs - - if self.with_adapter and self.training: - adp_inp = feats[-1][:, 0].detach() - masks = [] - for adp_i, adp_layer in enumerate(self.adp_layers): - m_ = adp_layer(adp_inp) - # if adp_i==0: - # m_ = m_.mean(1) - # m_ = torch.sigmoid(m_) - adp_inp = m_ * feats[adp_i][:, 0] + feats[adp_i][:, 0].detach() - masks.append(m_) - return adp_inp, torch.cat(masks, dim=1) - # return self.adp_norm(adp_inp.unsqueeze(1)).squeeze(1) - - if self.global_pool: - x = x[:, 1:, :].mean(dim=1) # global pool without cls token - return self.norm(x) - - x = self.norm(x) - if self.dist_token is None: - if prompt is not None: - return x[:, :prompt_length].mean(dim=1) - return self.pre_logits(x[:, 0]) - else: - return x[:, 0] # , x[:, 1] - - def forward(self, x, prompt=None, layer_feat=False): - x = self.forward_features(x, prompt, layer_feat) - if self.with_adapter and self.training: - x = {'masks': x[1], 'features': x[0]} - else: - x = {'features': x} - # if self.head_dist is not None: - # x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple - # if self.training and not torch.jit.is_scripting(): - # # during inference, return the average of both classifier predictions - # return x, x_dist - # else: - # return (x + x_dist) / 2 - # else: - # x = self.head(x) - return x - - -def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): - """ ViT weight initialization - * When called without n, head_bias, jax_impl args it will behave exactly the same - as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). - * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl - """ - if isinstance(module, nn.Linear): - if name.startswith('head'): - nn.init.zeros_(module.weight) - nn.init.constant_(module.bias, head_bias) - elif name.startswith('pre_logits'): - lecun_normal_(module.weight) - nn.init.zeros_(module.bias) - else: - if jax_impl: - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - if 'mlp' in name: - nn.init.normal_(module.bias, std=1e-6) - else: - nn.init.zeros_(module.bias) - else: - trunc_normal_(module.weight, std=.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif jax_impl and isinstance(module, nn.Conv2d): - # NOTE conv was left to pytorch default in my original init - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): - nn.init.zeros_(module.bias) - nn.init.ones_(module.weight) - - -@torch.no_grad() -def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): - """ Load weights from .npz checkpoints for official Google Brain Flax implementation - """ - import numpy as np - - def _n2p(w, t=True): - if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: - w = w.flatten() - if t: - if w.ndim == 4: - w = w.transpose([3, 2, 0, 1]) - elif w.ndim == 3: - w = w.transpose([2, 0, 1]) - elif w.ndim == 2: - w = w.transpose([1, 0]) - return torch.from_numpy(w) - - w = np.load(checkpoint_path) - if not prefix and 'opt/target/embedding/kernel' in w: - prefix = 'opt/target/' - - if hasattr(model.patch_embed, 'backbone'): - # hybrid - backbone = model.patch_embed.backbone - stem_only = not hasattr(backbone, 'stem') - stem = backbone if stem_only else backbone.stem - stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) - stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) - stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) - if not stem_only: - for i, stage in enumerate(backbone.stages): - for j, block in enumerate(stage.blocks): - bp = f'{prefix}block{i + 1}/unit{j + 1}/' - for r in range(3): - getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) - getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) - getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) - if block.downsample is not None: - block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) - block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) - block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) - embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) - else: - embed_conv_w = adapt_input_conv( - model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) - model.patch_embed.proj.weight.copy_(embed_conv_w) - model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) - model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) - pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) - if pos_embed_w.shape != model.pos_embed.shape: - pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) - model.pos_embed.copy_(pos_embed_w) - model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) - model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) - if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: - model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) - model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) - if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: - model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) - model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) - for i, block in enumerate(model.blocks.children()): - block_prefix = f'{prefix}Transformer/encoderblock_{i}/' - mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' - block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) - block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) - block.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) - block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - for r in range(2): - getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) - getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) - - -def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): - # Rescale the grid of position embeddings when loading from state_dict. Adapted from - # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 - _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) - ntok_new = posemb_new.shape[1] - if num_tokens: - posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] - ntok_new -= num_tokens - else: - posemb_tok, posemb_grid = posemb[:, :0], posemb[0] - gs_old = int(math.sqrt(len(posemb_grid))) - if not len(gs_new): # backwards compatibility - gs_new = [int(math.sqrt(ntok_new))] * 2 - assert len(gs_new) >= 2 - _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) - posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) - posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) - posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) - posemb = torch.cat([posemb_tok, posemb_grid], dim=1) - return posemb - - -def checkpoint_filter_fn(state_dict, model): - """ convert patch embedding weight from manual patchify + linear proj to conv""" - out_dict = {} - if 'model' in state_dict: - # For deit models - state_dict = state_dict['model'] - for k, v in state_dict.items(): - if 'patch_embed.proj.weight' in k and len(v.shape) < 4: - # For old models that I trained prior to conv based patchification - O, I, H, W = model.patch_embed.proj.weight.shape - v = v.reshape(O, -1, H, W) - elif k == 'pos_embed' and v.shape != model.pos_embed.shape: - # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed( - v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) - out_dict[k] = v - return out_dict - - -# def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): -# default_cfg = default_cfg or default_cfgs[variant] -# if kwargs.get('features_only', None): -# raise RuntimeError('features_only not implemented for Vision Transformer models.') - -# # NOTE this extra code to support handling of repr size for in21k pretrained models -# default_num_classes = default_cfg['num_classes'] -# num_classes = kwargs.get('num_classes', default_num_classes) -# repr_size = kwargs.pop('representation_size', None) -# if repr_size is not None and num_classes != default_num_classes: -# # Remove representation layer if fine-tuning. This may not always be the desired action, -# # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? -# _logger.warning("Removing representation layer for fine-tuning.") -# repr_size = None - -# model = build_model_with_cfg( -# VisionTransformer, variant, pretrained, -# default_cfg=default_cfg, -# representation_size=repr_size, -# pretrained_filter_fn=checkpoint_filter_fn, -# pretrained_custom_load='npz' in default_cfg['url'], -# **kwargs) -# return model - -def _create_vision_transformer(variant, pretrained=False, **kwargs): - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - if 'flexi' in variant: - # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed - # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. - _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) - else: - _filter_fn = checkpoint_filter_fn - - # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln? - strict = True - if 'siglip' in variant and kwargs.get('global_pool', None) != 'map': - strict = False - - pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None) or default_cfgs[variant]) - pretrained_cfg.custom_load = True - - return build_model_with_cfg( - VisionTransformer, - variant, - pretrained, - pretrained_cfg=pretrained_cfg, - pretrained_filter_fn=_filter_fn, - pretrained_strict=strict, - **kwargs, - ) - - -def vit_base_patch16_224_in21k(pretrained=False, adapter=False, **kwargs): - """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer - """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, with_adapter=adapter, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) - del model.head - del model.norm - model.norm = nn.LayerNorm(768) - return model - - -def vit_base_patch16_224_mocov3(pretrained=False, adapter=False, **kwargs): - """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer - """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, with_adapter=adapter, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=False, **model_kwargs) - del model.head - ckpt = torch.load('mocov3-vit-base-300ep.pth', map_location='cpu')['model'] - state_dict = model.state_dict() - state_dict.update(ckpt) - model.load_state_dict(state_dict) - del model.norm - model.norm = nn.LayerNorm(768) - return model diff --git a/models/slca_utils/inc_net.py b/models/slca_utils/inc_net.py index ce5c49ef..ce791b64 100644 --- a/models/slca_utils/inc_net.py +++ b/models/slca_utils/inc_net.py @@ -1,11 +1,14 @@ import copy +import os +import sys import torch from torch import nn +import torch.nn.functional as F +from backbone.ResNetBlock import resnet18, resnet34 +from backbone.ResNetBottleneck import resnet50 +from backbone.vit import vit_base_patch16_224_prompt_prototype from models.slca_utils.convs.cifar_resnet import resnet32 -from models.slca_utils.convs.resnet import resnet18, resnet34, resnet50 from models.slca_utils.convs.linears import SimpleContinualLinear -from models.slca_utils.convs.vits import vit_base_patch16_224_in21k, vit_base_patch16_224_mocov3 -import torch.nn.functional as F def get_convnet(feature_extractor_type, pretrained=False): @@ -24,9 +27,25 @@ def get_convnet(feature_extractor_type, pretrained=False): return resnet50(pretrained=pretrained) elif name == 'vit-b-p16': print("Using ViT-B/16 pretrained on ImageNet21k (NO FINETUNE ON IN1K)") - return vit_base_patch16_224_in21k(pretrained=pretrained) + model = vit_base_patch16_224_prompt_prototype(pretrained=pretrained, pretrain_type='in21k', num_classes=0) + model.norm = nn.LayerNorm(model.embed_dim) # from the original implementation + return model elif name == 'vit-b-p16-mocov3': - return vit_base_patch16_224_mocov3(pretrained=True) + model = vit_base_patch16_224_prompt_prototype(pretrained=pretrained, pretrain_type='in21k', num_classes=0) + + del model.head + if not os.path.exists('mocov3-vit-base-300ep.pth'): + print("Cannot find the pretrained model for MoCoV3-ViT-B/16") + print("Please download the model from https://drive.google.com/file/d/1bshDu4jEKztZZvwpTVXSAuCsDoXwCkfy/view?usp=share_link") + sys.exit(1) + + ckpt = torch.load('mocov3-vit-base-300ep.pth', map_location='cpu')['model'] # from the original implementation + state_dict = model.state_dict() + state_dict.update(ckpt) + model.load_state_dict(state_dict) + del model.norm + model.norm = nn.LayerNorm(model.embed_dim) + return model else: raise NotImplementedError('Unknown type {}'.format(feature_extractor_type)) @@ -44,11 +63,11 @@ def feature_dim(self): return self.convnet.out_dim def extract_vector(self, x): - return self.convnet(x)['features'] + return self.convnet(x, returnt='features') def forward(self, x): - x = self.convnet(x) - out = self.fc(x['features']) + x = self.convnet(x, returnt='features') + out = self.fc(x) ''' { 'fmaps': [x_1, x_2, ..., x_n], @@ -56,7 +75,7 @@ def forward(self, x): 'logits': logits } ''' - out.update(x) + out.update({'features': x}) return out @@ -84,19 +103,9 @@ def __init__(self, feature_extractor_type, pretrained, fc_with_ln=False): self.old_fc = None self.fc_with_ln = fc_with_ln - def extract_layerwise_vector(self, x, pool=True): - with torch.no_grad(): - features = self.convnet(x, layer_feat=True)['features'] - for f_i in range(len(features)): - if pool: - features[f_i] = features[f_i].mean(1).cpu().numpy() - else: - features[f_i] = features[f_i][:, 0].cpu().numpy() - return features - def update_fc(self, nb_classes, freeze_old=True): if self.fc is None: - self.fc = self.generate_fc(self.feature_dim, nb_classes) + self.fc = self.generate_fc(self.convnet.feature_dim, nb_classes) else: self.fc.update(nb_classes, freeze_old=freeze_old) @@ -120,10 +129,10 @@ def forward(self, x, bcb_no_grad=False, fc_only=False): return fc_out if bcb_no_grad: with torch.no_grad(): - x = self.convnet(x) + x = self.convnet(x, returnt='features') else: - x = self.convnet(x) - out = self.fc(x['features']) - out.update(x) + x = self.convnet(x, returnt='features') + out = self.fc(x) + out.update({'features': x}) return out diff --git a/models/twf.py b/models/twf.py index 29d49068..14648535 100644 --- a/models/twf.py +++ b/models/twf.py @@ -9,6 +9,8 @@ from torchvision import transforms import torch.nn.functional as F +from utils.kornia_utils import KorniaMultiAug + def batch_iterate(size: int, batch_size: int): n_chunks = size // batch_size @@ -51,7 +53,7 @@ def __init__(self, backbone, loss, args, transform): backbone, loss, args, transform) self.buffer = Buffer(self.args.buffer_size) - self.buf_transform = self.get_custom_double_transform(self.transform.transforms) + self.buf_transform = self.get_custom_double_transform(self.original_transform.transforms) if self.args.loadcheck is None: print("Warning: no checkpoint loaded!") diff --git a/models/utils/continual_model.py b/models/utils/continual_model.py index c8a897fc..97fcfff8 100644 --- a/models/utils/continual_model.py +++ b/models/utils/continual_model.py @@ -138,6 +138,7 @@ def __init__(self, backbone: nn.Module, loss: nn.Module, self.net = backbone self.loss = loss self.args = args + self.original_transform = transform self.transform = transform self.dataset = get_dataset(self.args) self.N_CLASSES = self.dataset.N_CLASSES @@ -149,11 +150,10 @@ def __init__(self, backbone: nn.Module, loss: nn.Module, self._current_task = 0 try: - self.weak_transform = to_kornia_transform(transform.transforms[-1].transforms) + self.transform = to_kornia_transform(transform.transforms[-1].transforms) self.normalization_transform = to_kornia_transform(self.dataset.get_normalization_transform()) except BaseException: print("Warning: could not initialize kornia transforms.") - self.weak_transform = transforms.Compose([transforms.ToPILImage(), self.transform]) self.normalization_transform = transforms.Compose([transforms.ToPILImage(), self.dataset.TEST_TRANSFORM]) if hasattr( self.dataset, 'TEST_TRANSFORM') else transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), self.dataset.get_normalization_transform()]) diff --git a/models/xder_ce.py b/models/xder_ce.py index e40b2a8e..7916684a 100644 --- a/models/xder_ce.py +++ b/models/xder_ce.py @@ -29,9 +29,9 @@ def get_parser() -> ArgumentParser: parser.add_argument('--eta', type=float, default=0.1) parser.add_argument('--m', type=float, default=0.3) - parser.add_argument('--past_constraint', type=int, default=1, choices=[0,1], help='Enable past constraint') - parser.add_argument('--future_constraint', type=int, default=1, choices=[0,1], help='Enable future constraint') - parser.add_argument('--align_bn', type=int, default=0, choices=[0,1], help='Use BatchNorm alignment') + parser.add_argument('--past_constraint', type=int, default=1, choices=[0, 1], help='Enable past constraint') + parser.add_argument('--future_constraint', type=int, default=1, choices=[0, 1], help='Enable future constraint') + parser.add_argument('--align_bn', type=int, default=0, choices=[0, 1], help='Use BatchNorm alignment') return parser @@ -144,7 +144,7 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None): self.opt.zero_grad() - with bn_track_stats(self, self.args.align_bn==0 or self.current_task == 0): + with bn_track_stats(self, self.args.align_bn == 0 or self.current_task == 0): outputs = self.net(inputs) # Present head @@ -159,7 +159,7 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None): buf_inputs1 = torch.cat([buf_inputs1, inputs[:self.args.minibatch_size // self.current_task]]) buf_outputs1 = self.net(buf_inputs1) - + if self.args.align_bn: buf_inputs1 = buf_inputs1[:self.args.minibatch_size] buf_outputs1 = buf_outputs1[:self.args.minibatch_size] @@ -171,7 +171,7 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None): # Label Replay Loss (past heads) buf_idx2, buf_inputs2, buf_labels2, buf_logits2, buf_tl2 = self.buffer.get_data( self.args.minibatch_size, transform=self.transform, return_index=True, device=self.device) - with bn_track_stats(self, self.args.align_bn==0): + with bn_track_stats(self, self.args.align_bn == 0): buf_outputs2 = self.net(buf_inputs2) _, offset = self.dataset.get_offsets(self.current_task + (1 if self.current_task == 0 else 0)) diff --git a/models/xder_rpc.py b/models/xder_rpc.py index 6f09f15e..503798bf 100644 --- a/models/xder_rpc.py +++ b/models/xder_rpc.py @@ -73,7 +73,7 @@ def get_parser() -> ArgumentParser: parser.add_argument('--m', type=float, default=0.3) parser.add_argument('--clip_grad', type=none_or_float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') - parser.add_argument('--align_bn', type=int, default=0, choices=[0,1], help='Use BatchNorm alignment') + parser.add_argument('--align_bn', type=int, default=0, choices=[0, 1], help='Use BatchNorm alignment') parser.add_argument('--n_rpc_heads', type=int, help='N Heads for RPC') return parser @@ -190,7 +190,7 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None): self.opt.zero_grad() - with bn_track_stats(self, self.args.align_bn==0 or self.current_task == 0): + with bn_track_stats(self, self.args.align_bn == 0 or self.current_task == 0): outputs = self(inputs) # Present head @@ -205,7 +205,7 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None): buf_inputs1 = torch.cat([buf_inputs1, inputs[:self.args.minibatch_size // self.current_task]]) buf_outputs1 = self(buf_inputs1) - + if self.args.align_bn: buf_inputs1 = buf_inputs1[:self.args.minibatch_size] buf_outputs1 = buf_outputs1[:self.args.minibatch_size] @@ -217,7 +217,7 @@ def observe(self, inputs, labels, not_aug_inputs, epoch=None): # Label Replay Loss (past heads) buf_idx2, buf_inputs2, buf_labels2, buf_logits2, buf_tl2 = self.buffer.get_data( self.args.minibatch_size, transform=self.transform, return_index=True, device=self.device) - with bn_track_stats(self, self.args.align_bn==0): + with bn_track_stats(self, self.args.align_bn == 0): buf_outputs2 = self(buf_inputs2).float() buf_ce = self.loss(buf_outputs2[:, :self.n_past_classes], buf_labels2) diff --git a/optional-requirements.txt b/optional-requirements.txt deleted file mode 100644 index fafa0b15..00000000 --- a/optional-requirements.txt +++ /dev/null @@ -1 +0,0 @@ -setproctitle \ No newline at end of file diff --git a/requirements-optional.txt b/requirements-optional.txt new file mode 100644 index 00000000..0e9c2c08 --- /dev/null +++ b/requirements-optional.txt @@ -0,0 +1,6 @@ +googledrivedownloader==0.4 +onedrivedownloader==1.1.3 +pytest==7.4.2 +quadprog==0.1.11 +setproctitle==1.3.2 +wandb \ No newline at end of file diff --git a/scripts/prepare_grid.py b/scripts/prepare_grid.py index 287a6d17..2707c85b 100644 --- a/scripts/prepare_grid.py +++ b/scripts/prepare_grid.py @@ -10,10 +10,10 @@ grid_combinations = [ { - 'name':'experiment_name', + 'name': 'experiment_name', 'combos': { - 'lr': [0.01,0.3,0.05], - 'buffer_size':[500], + 'lr': [0.01, 0.3, 0.05], + 'buffer_size': [500], 'model': ['er'], 'dataset': ['seq-cifar10'] }, @@ -41,17 +41,16 @@ for k, v in zip(combos.keys(), c): if v is None: continue - if type(k) == tuple: - for i in range(len(k)): +if isinstance(k, if) for i in range(len(k)): ll += f" --{k[i]}={v[i]}" else: ll += f" --{k}={v}" - f.write(ll+'\n') + f.write(ll +'\n') all_configs.append(ll) clines += 1 - print(f"Total ({filenam}):",clines) + print(f"Total ({filenam}):", clines) print(f'{folder}list_all_grid.txt') clines = 0 @@ -60,5 +59,5 @@ f.write(ll + '\n') clines += 1 -print("Total (all):",clines) +print("Total (all):", clines) print('') diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_ccic.py b/tests/test_ccic.py index 9553575c..9b6ef49a 100644 --- a/tests/test_ccic.py +++ b/tests/test_ccic.py @@ -35,7 +35,7 @@ def test_ccic(dataset, label_perc): # log all outputs to file if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) - sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_ccic.{dataset}.log'), 'w', encoding='utf-8') + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_ccic.{dataset}.{label_perc}.log'), 'w', encoding='utf-8') sys.stderr = sys.stdout main() diff --git a/tests/test_checkpointing.py b/tests/test_checkpointing.py new file mode 100644 index 00000000..318e7a05 --- /dev/null +++ b/tests/test_checkpointing.py @@ -0,0 +1,178 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.main import main +import pytest + + +@pytest.mark.parametrize('model', ['sgd', 'slca', 'l2p']) +def test_checkpointing_bufferfree(model): + N_TASKS = 5 # cifar10 + + # TEST CHECKPOINT SAVE + sys.argv = ['mammoth', + '--model', + model, + '--dataset', + 'seq-cifar10-224', + '--lr', + '1e-4', + '--n_epochs', + '1', + '--savecheck', + '1', + '--batch_size', + '4', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--debug_mode', + '1'] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.{model}.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + main() + + # read output file and search for the string 'Saving checkpoint into' + with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.{model}.log'), 'r', encoding='utf-8') as f: + lines = f.readlines() + ckpt_name = [line for line in lines if 'Saving checkpoint into' in line] + assert any(ckpt_name), f'Checkpoint not saved for model {model}' + + ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_{N_TASKS-1}.pt' + ckpt_path = os.path.join('checkpoints', ckpt_name) + + assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found' + + # TEST CHECKPOINT LOAD + sys.argv = ['mammoth', + '--model', + model, + '--dataset', + 'seq-cifar10-224', + '--lr', + '1e-4', + '--n_epochs', + '1', + '--loadcheck', + ckpt_path, + '--batch_size', + '4', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--debug_mode', + '1'] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_load.{model}.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + main() + + # REMOVE CHECKPOINT FILE + for i in range(N_TASKS): + c_path = ckpt_path.split(f'_{N_TASKS-1}.pt')[0] + f'_{i}.pt' + os.remove(c_path) + + +def test_checkpointing_replay(): + N_TASKS = 5 # cifar10 + + # TEST CHECKPOINT SAVE + sys.argv = ['mammoth', + '--model', + 'derpp', + '--dataset', + 'seq-cifar10', + '--alpha', + '0.1', + '--beta', + '0.1', + '--lr', + '1e-4', + '--n_epochs', + '1', + '--buffer_size', + '50', + '--savecheck', + '1', + '--batch_size', + '4', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--debug_mode', + '1'] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.derpp.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + main() + + # read output file and search for the string 'Saving checkpoint into' + with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_save.derpp.log'), 'r', encoding='utf-8') as f: + lines = f.readlines() + ckpt_name = [line for line in lines if 'Saving checkpoint into' in line] + assert any(ckpt_name), f'Checkpoint not saved for derpp' + + ckpt_name = ckpt_name[0].split('Saving checkpoint into')[-1].strip() + f'_{N_TASKS-1}.pt' + ckpt_path = os.path.join('checkpoints', ckpt_name) + + assert os.path.exists(ckpt_path), f'Checkpoint file {ckpt_path} not found' + + # TEST CHECKPOINT LOAD + sys.argv = ['mammoth', + '--model', + 'derpp', + '--dataset', + 'seq-cifar10', + '--alpha', + '0.1', + '--beta', + '0.1', + '--lr', + '1e-4', + '--n_epochs', + '1', + '--buffer_size', + '50', + '--loadcheck', + ckpt_path, + '--batch_size', + '4', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--debug_mode', + '1'] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_checkpoint_load.derpp.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + main() + + # REMOVE CHECKPOINT FILE + for i in range(N_TASKS): + c_path = ckpt_path.split(f'_{N_TASKS-1}.pt')[0] + f'_{i}.pt' + os.remove(c_path) diff --git a/tests/test_codaprompt.py b/tests/test_codaprompt.py new file mode 100644 index 00000000..fef858e4 --- /dev/null +++ b/tests/test_codaprompt.py @@ -0,0 +1,39 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.main import main, parse_args +import pytest + + +@pytest.mark.parametrize('dataset', ['seq-cifar10-224', 'seq-imagenet-r']) +@pytest.mark.parametrize('code_optimization', [0, 1]) +def test_codaprompt(dataset, code_optimization): + sys.argv = ['mammoth', + '--model', + 'coda_prompt', + '--dataset', + dataset, + '--lr', + '1e-4', + '--n_epochs', + '1', + '--batch_size', + '2', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--code_optimization', + str(code_optimization), + '--debug_mode', + '1'] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_codaprompt.{dataset}.O{code_optimization}.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + + main() diff --git a/tests/test_code_optimization.py b/tests/test_code_optimization.py new file mode 100644 index 00000000..eb03a552 --- /dev/null +++ b/tests/test_code_optimization.py @@ -0,0 +1,71 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.main import main +import pytest + + +@pytest.mark.parametrize('code_optimization', [0, 1, 2, 3]) +def test_code_optim_erace(code_optimization): + sys.argv = ['mammoth', + '--model', + 'er-ace', + '--buffer_size', + '50', + '--dataset', + 'seq-cifar10', + '--lr', + '1e-3', + '--n_epochs', + '1', + '--batch_size', + '4', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--debug_mode', + '1', + '--code_optimization', + str(code_optimization)] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_code_optimization.O{code_optimization}.er-ace.seq-cifar10.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + main() + + +@pytest.mark.parametrize('code_optimization', [0, 1, 2, 3]) +def test_code_optimization_slca(code_optimization): + sys.argv = ['mammoth', + '--model', + 'slca', + '--dataset', + 'seq-cifar10-224', + '--lr', + '1e-3', + '--n_epochs', + '1', + '--batch_size', + '4', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--debug_mode', + '1', + '--code_optimization', + str(code_optimization)] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_code_optimization.{code_optimization}.slca.seq-cifar10-224.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + main() diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bf0bbc62..adc6bfb1 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -7,6 +7,7 @@ @pytest.mark.parametrize('dataset', ['seq-mnist', 'seq-cifar10', 'seq-cifar100', 'seq-tinyimg', 'rot-mnist', 'perm-mnist', 'mnist-360', 'seq-cifar100-224', + 'seq-cifar10-224', 'seq-cifar100-224-rs', 'seq-cifar100-224-rs', 'seq-tinyimg-r', 'seq-cub200', 'seq-imagenet-r']) def test_datasets(dataset): sys.argv = ['mammoth', @@ -29,10 +30,18 @@ def test_datasets(dataset): '--debug_mode', '1'] + # clean all downloaded datasets + dataset_paths = ['CUB200', 'CIFAR10', 'CIFAR100', 'MNIST', 'TINYIMG', 'imagenet-r'] + basepath = os.path.dirname(os.path.abspath(__file__)) + dt_dir = os.path.join(os.path.dirname(basepath), 'data') + for path in dataset_paths: + if os.path.exists(os.path.join(dt_dir, path)): + os.system(f'rm -rf {os.path.join(dt_dir, path)}') + # log all outputs to file - if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): - os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) - sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_datasets.{dataset}.log'), 'w', encoding='utf-8') + if not os.path.exists(os.path.join(basepath, 'logs')): + os.mkdir(os.path.join(basepath, 'logs')) + sys.stdout = open(os.path.join(basepath, 'logs', f'test_datasets.{dataset}.log'), 'w', encoding='utf-8') sys.stderr = sys.stdout main() diff --git a/tests/test_dualprompt.py b/tests/test_dualprompt.py index 3a7a5757..a9963c70 100644 --- a/tests/test_dualprompt.py +++ b/tests/test_dualprompt.py @@ -5,7 +5,7 @@ import pytest -def test_l2p(): +def test_dualprompt(): sys.argv = ['mammoth', '--model', 'dualprompt', diff --git a/tests/test_twf.py b/tests/test_twf.py index a80c517c..04df890c 100644 --- a/tests/test_twf.py +++ b/tests/test_twf.py @@ -41,6 +41,7 @@ def test_twf_random_init(dataset, resize_maps): '0', '--seed', '0', + # '-O2', '--debug_mode', '1'] diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 00000000..3e2b055c --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,77 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.main import main, parse_args +import pytest + +@pytest.mark.parametrize('validation', ['0.2','0','20']) +@pytest.mark.parametrize('validation_mode', ['complete','current']) +def test_validation_classil( validation, validation_mode): + sys.argv = ['mammoth', + '--model', + 'sgd', + '--dataset', + 'seq-cifar10', + '--lr', + '1e-4', + '--n_epochs', + '1', + '--validation', + validation, + '--validation_mode', + validation_mode, + '--batch_size', + '4', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--debug_mode', + '1'] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_validation_classil.seq-cifar10.{validation}.{validation_mode}.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + + main() + + +@pytest.mark.parametrize('dataset', ['mnist-360','perm-mnist']) +@pytest.mark.parametrize('validation', ['0.2','0','20']) +@pytest.mark.parametrize('validation_mode', ['complete']) +def test_validation_domainil(dataset, validation, validation_mode): + sys.argv = ['mammoth', + '--model', + 'sgd', + '--dataset', + dataset, + '--lr', + '1e-4', + '--n_epochs', + '1', + '--validation', + validation, + '--validation_mode', + validation_mode, + '--batch_size', + '4', + '--non_verbose', + '1', + '--num_workers', + '0', + '--seed', + '0', + '--debug_mode', + '1'] + + # log all outputs to file + if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')): + os.mkdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs')) + sys.stdout = open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logs', f'test_validation_domainil.{dataset}.{validation}.{validation_mode}.log'), 'w', encoding='utf-8') + sys.stderr = sys.stdout + + main() diff --git a/utils/args.py b/utils/args.py index 2ea10be1..c102f8eb 100644 --- a/utils/args.py +++ b/utils/args.py @@ -26,48 +26,60 @@ def add_experiment_args(parser: ArgumentParser) -> None: Returns: None """ - parser.add_argument('--dataset', type=str, required=True, - choices=get_dataset_names(), - help='Which dataset to perform experiments on.') - parser.add_argument('--model', type=custom_str_underscore, required=True, - help='Model name.', choices=list(get_all_models().keys())) - - parser.add_argument('--lr', type=float, required=True, - help='Learning rate.') - - parser.add_argument('--optimizer', type=str, default='sgd', - choices=ContinualModel.AVAIL_OPTIMS, - help='Optimizer.') - parser.add_argument('--optim_wd', type=float, default=0., - help='optimizer weight decay.') - parser.add_argument('--optim_mom', type=float, default=0., - help='optimizer momentum.') - parser.add_argument('--optim_nesterov', type=int, default=0, - help='optimizer nesterov momentum.') - - parser.add_argument('--lr_scheduler', type=str, help='Learning rate scheduler.') - parser.add_argument('--lr_milestones', type=int, nargs='+', default=[], - help='Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`).') - parser.add_argument('--sched_multistep_lr_gamma', type=float, default=0.1, - help='Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`).') - - parser.add_argument('--n_epochs', type=int, - help='Number of epochs.') - parser.add_argument('--batch_size', type=int, - help='Batch size.') - - parser.add_argument('--distributed', type=str, default='no', choices=['no', 'dp', 'ddp'], - help='Enable distributed training?') - parser.add_argument('--savecheck', action='store_true', help='Save checkpoint?') - parser.add_argument('--loadcheck', type=str, default=None, help='Path of the checkpoint to load (.pt file for the specific task)') - parser.add_argument('--ckpt_name', type=str, required=False, help='(optional) checkpoint save name.') - parser.add_argument('--start_from', type=int, default=None, help="Task to start from") - parser.add_argument('--stop_after', type=int, default=None, help="Task limit") - - parser.add_argument('--joint', type=int, choices=[0, 1], default=0, - help='Train model on Joint (single task)?') - parser.add_argument('--label_perc', type=float, default=1, - help='Percentage in (0-1] of labeled examples per task.') + exp_group = parser.add_argument_group('Experiment arguments', 'Arguments used to define the experiment settings.') + + exp_group.add_argument('--dataset', type=str, required=True, + choices=get_dataset_names(), + help='Which dataset to perform experiments on.') + exp_group.add_argument('--model', type=custom_str_underscore, required=True, + help='Model name.', choices=list(get_all_models().keys())) + exp_group.add_argument('--lr', type=float, required=True, help='Learning rate.') + exp_group.add_argument('--batch_size', type=int, help='Batch size.') + exp_group.add_argument('--label_perc', type=float, default=1, help='Percentage in (0-1] of labeled examples per task.') + exp_group.add_argument('--joint', type=int, choices=[0, 1], default=0, help='Train model on Joint (single task)?') + + validation_group = parser.add_argument_group('Validation and fitting arguments', 'Arguments used to define the validation strategy and the method used to fit the model.') + + validation_group.add_argument('--validation', type=float, help='Percentage of samples FOR EACH CLASS drawn from the training set to build the validation set.') + validation_group.add_argument('--validation_mode', type=str, choices=['complete', 'current'], default='current', + help='Mode used for validation. Must be used in combination with `validation` argument. Possible values:' + ' - `current`: uses only the current task for validation (default).' + ' - `complete`: uses data from both current and past tasks for validation.') + validation_group.add_argument('--fitting_mode', type=str, choices=['epochs', 'iters', 'time', 'early_stopping'], default='epochs', + help='Strategy used for fitting the model. Possible values:' + ' - `epochs`: fits the model for a fixed number of epochs (default). NOTE: this option is controlled by the `n_epochs` argument.' + ' - `iters`: fits the model for a fixed number of iterations. NOTE: this option is controlled by the `n_iters` argument.' + ' - `early_stopping`: fits the model until early stopping criteria are met. This option requires a validation set (see `validation` argument).' + ' The early stopping criteria are: if the validation loss does not decrease for `early_stopping_patience` epochs, the training stops.') + validation_group.add_argument('--early_stopping_patience', type=int, default=5, + help='Number of epochs to wait before stopping the training if the validation loss does not decrease. Used only if `fitting_mode=early_stopping`.') + validation_group.add_argument('--early_stopping_metric', type=str, default='loss', choices=['loss', 'accuracy'], + help='Metric used for early stopping. Used only if `fitting_mode=early_stopping`.') + validation_group.add_argument('--early_stopping_freq', type=int, default=1, + help='Frequency of validation evaluation. Used only if `fitting_mode=early_stopping`.') + validation_group.add_argument('--early_stopping_epsilon', type=float, default=1e-6, + help='Minimum improvement required to consider a new best model. Used only if `fitting_mode=early_stopping`.') + validation_group.add_argument('--n_epochs', type=int, + help='Number of epochs. Used only if `fitting_mode=epochs`.') + validation_group.add_argument('--n_iters', type=int, + help='Number of iterations. Used only if `fitting_mode=iters`.') + + opt_group = parser.add_argument_group('Optimizer and learning rate scheduler arguments', 'Arguments used to define the optimizer and the learning rate scheduler.') + + opt_group.add_argument('--optimizer', type=str, default='sgd', + choices=ContinualModel.AVAIL_OPTIMS, + help='Optimizer.') + opt_group.add_argument('--optim_wd', type=float, default=0., + help='optimizer weight decay.') + opt_group.add_argument('--optim_mom', type=float, default=0., + help='optimizer momentum.') + opt_group.add_argument('--optim_nesterov', type=int, default=0, + help='optimizer nesterov momentum.') + opt_group.add_argument('--lr_scheduler', type=str, help='Learning rate scheduler.') + opt_group.add_argument('--lr_milestones', type=int, nargs='+', default=[], + help='Learning rate scheduler milestones (used if `lr_scheduler=multisteplr`).') + opt_group.add_argument('--sched_multistep_lr_gamma', type=float, default=0.1, + help='Learning rate scheduler gamma (used if `lr_scheduler=multisteplr`).') def add_management_args(parser: ArgumentParser) -> None: @@ -80,32 +92,45 @@ def add_management_args(parser: ArgumentParser) -> None: Returns: None """ - parser.add_argument('--seed', type=int, default=None, - help='The random seed.') - parser.add_argument('--permute_classes', type=int, choices=[0, 1], default=0, - help='Permute classes before splitting tasks (applies seed before permute if seed is present)?') - parser.add_argument('--base_path', type=str, default="./data/", - help='The base path where to save datasets, logs, results.') - parser.add_argument('--notes', type=str, default=None, - help='Notes for this run.') - parser.add_argument('--wandb_name', type=str, default=None, - help='Wandb name for this run. Overrides the default name (`args.model`).') - - parser.add_argument('--non_verbose', default=0, choices=[0, 1], type=int, help='Make progress bars non verbose') - parser.add_argument('--disable_log', default=0, choices=[0, 1], type=int, help='Disable logging?') - parser.add_argument('--num_workers', type=int, default=None, help='Number of workers for the dataloaders (default=infer from number of cpus).') - - parser.add_argument('--validation', type=int, help='Percentage of validation set drawn from the training set.') - parser.add_argument('--enable_other_metrics', default=0, choices=[0, 1], type=int, - help='Enable computing additional metrics: forward and backward transfer.') - parser.add_argument('--debug_mode', type=int, default=0, choices=[0, 1], help='Run only a few forward steps per epoch') - parser.add_argument('--wandb_entity', type=str, help='Wandb entity') - parser.add_argument('--wandb_project', type=str, default='mammoth', help='Wandb project name') - - parser.add_argument('--eval_epochs', type=int, default=None, - help='Perform inference intra-task at every `eval_epochs`.') - parser.add_argument('--inference_only', action="store_true", - help='Perform inference only for each task (no training).') + mng_group = parser.add_argument_group('Management arguments', 'Generic arguments to manage the experiment reproducibility, logging, debugging, etc.') + + mng_group.add_argument('--seed', type=int, default=None, + help='The random seed. If not provided, a random seed will be used.') + mng_group.add_argument('--permute_classes', type=int, choices=[0, 1], default=0, + help='Permute classes before splitting into tasks? This applies the seed before permuting if the `seed` argument is present.') + mng_group.add_argument('--base_path', type=str, default="./data/", + help='The base path where to save datasets, logs, results.') + mng_group.add_argument('--notes', type=str, default=None, + help='Helper argument to include notes for this run. Example: distinguish between different versions of a model and allow separation of results') + mng_group.add_argument('--eval_epochs', type=int, default=None, + help='Perform inference on validation every `eval_epochs` epochs. If not provided, the model is evaluated ONLY at the end of each task.') + mng_group.add_argument('--non_verbose', default=0, choices=[0, 1], type=int, help='Make progress bars non verbose') + mng_group.add_argument('--disable_log', default=0, choices=[0, 1], type=int, help='Disable logging?') + mng_group.add_argument('--num_workers', type=int, default=None, help='Number of workers for the dataloaders (default=infer from number of cpus).') + mng_group.add_argument('--enable_other_metrics', default=0, choices=[0, 1], type=int, + help='Enable computing additional metrics: forward and backward transfer.') + mng_group.add_argument('--debug_mode', type=int, default=0, choices=[0, 1], help='Run only a few training steps per epoch. This also disables logging on wandb.') + mng_group.add_argument('--inference_only', default=0, choices=[0, 1], type=int, + help='Perform inference only for each task (no training).') + mng_group.add_argument('-O', '--code_optimization', type=int, default=0, choices=[0, 1, 2, 3], + help='Optimization level for the code.' + '0: no optimization.' + '1: Use TF32, if available.' + '2: Use BF16, if available.' + '3: Use BF16 and `torch.compile`. BEWARE: torch.compile may break your code if you change the model after the first run! Use with caution.') + mng_group.add_argument('--distributed', type=str, default='no', choices=['no', 'dp', 'ddp'], help='Enable distributed training?') + mng_group.add_argument('--savecheck', default=0, choices=[0, 1], type=int, help='Save checkpoint?') + mng_group.add_argument('--loadcheck', type=str, default=None, help='Path of the checkpoint to load (.pt file for the specific task)') + mng_group.add_argument('--ckpt_name', type=str, required=False, help='(optional) checkpoint save name.') + mng_group.add_argument('--start_from', type=int, default=None, help="Task to start from") + mng_group.add_argument('--stop_after', type=int, default=None, help="Task limit") + + wandb_group = parser.add_argument_group('Wandb arguments', 'Arguments to manage logging on Wandb.') + + wandb_group.add_argument('--wandb_name', type=str, default=None, + help='Wandb name for this run. Overrides the default name (`args.model`).') + wandb_group.add_argument('--wandb_entity', type=str, help='Wandb entity') + wandb_group.add_argument('--wandb_project', type=str, default='mammoth', help='Wandb project name') def add_rehearsal_args(parser: ArgumentParser) -> None: @@ -118,10 +143,12 @@ def add_rehearsal_args(parser: ArgumentParser) -> None: Returns: None """ - parser.add_argument('--buffer_size', type=int, required=True, - help='The size of the memory buffer.') - parser.add_argument('--minibatch_size', type=int, - help='The batch size of the memory buffer.') + group = parser.add_argument_group('Rehearsal arguments', 'Arguments shared by all rehearsal-based methods.') + + group.add_argument('--buffer_size', type=int, required=True, + help='The size of the memory buffer.') + group.add_argument('--minibatch_size', type=int, + help='The batch size of the memory buffer.') class _DocsArgs: @@ -149,6 +176,41 @@ def __str__(self): - Choices: {self.parse_choices() if self.choices is not None else ''}""" +class _DocArgsGroup: + """ + This class is used to generate the documentation of the arguments. + """ + + def __init__(self, group_name: str, group_desc: str, doc_args: _DocsArgs): + self.group_name = group_name + self.group_desc = group_desc + self.doc_args = doc_args + + def __str__(self): + args_str = '\n'.join([arg.__str__() for arg in self.doc_args]) + return f""".. rubric:: {self.group_name.capitalize()}\n\n*{self.group_desc}*\n\n{args_str}""" + + +def _parse_actions(actions: list, group_name: str, group_desc: str) -> _DocArgsGroup: + """ + Parses the actions of the parser. + + Args: + actions: the actions to parse + group_name: the name of the group + group_desc: the description of the group + + Returns: + an instance of _DocArgsGroup containing the parsed actions + """ + docs_args = [] + for action in actions: + if action.dest == 'help': + continue + docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) + return _DocArgsGroup(group_name, group_desc, docs_args) + + if __name__ == '__main__': print("Generating documentation for the arguments...") os.chdir(mammoth_path) @@ -156,10 +218,8 @@ def __str__(self): add_experiment_args(parser) docs_args = [] - for action in parser._actions: - if action.dest == 'help': - continue - docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) + for group in parser._action_groups[2:]: # first two groups are the positional and optional arguments + docs_args.append(_parse_actions(group._group_actions, group.title, group.description)) with open('docs/utils/args.rst', 'w') as f: f.write('.. _module-args:\n\n') @@ -172,10 +232,8 @@ def __str__(self): parser = ArgumentParser() add_management_args(parser) docs_args = [] - for action in parser._actions: - if action.dest == 'help': - continue - docs_args.append(_DocsArgs(action.dest, action.type, action.choices, action.default, action.help)) + for group in parser._action_groups[2:]: # first two groups are the positional and optional arguments + docs_args.append(_parse_actions(group._group_actions, group.title, group.description)) with open('docs/utils/args.rst', 'a') as f: f.write('.. rubric:: MANAGEMENT ARGS\n\n') diff --git a/utils/augmentations.py b/utils/augmentations.py index e7562335..efa0c763 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -7,6 +7,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import PIL import numpy as np import torch import torch.nn.functional as F @@ -30,6 +31,8 @@ def apply_transform(x: torch.Tensor, transform) -> torch.Tensor: """ if isinstance(transform, KorniaAugNoGrad): + if isinstance(x, PIL.Image.Image): + x = torch.as_tensor(np.array(x, copy=True)).permute((2, 0, 1)) return transform(x) else: return torch.stack([transform(xi) for xi in x.cpu()], dim=0).to(x.device) diff --git a/utils/buffer.py b/utils/buffer.py index 5e4bfe7f..211691bc 100644 --- a/utils/buffer.py +++ b/utils/buffer.py @@ -223,7 +223,7 @@ def get_data(self, size: int, transform: nn.Module = None, return_index=False, d def transform(x): return x selected_samples = self.examples[choice] if mask_task_out is None else self.examples[samples_mask][choice] - ret_tuple = (torch.stack([transform(ee) for ee in selected_samples.cpu()]).to(target_device),) + ret_tuple = (apply_transform(selected_samples, transform=transform).to(target_device),) for attr_str in self.attributes[1:]: if hasattr(self, attr_str): attr = getattr(self, attr_str) @@ -250,7 +250,7 @@ def get_data_by_index(self, indexes, transform: nn.Module = None, device=None) - if transform is None: def transform(x): return x - ret_tuple = (apply_transform(self.examples[:len(self)], transform=transform).to(target_device),) + ret_tuple = (apply_transform(self.examples[indexes], transform=transform).to(target_device),) for attr_str in self.attributes[1:]: if hasattr(self, attr_str): attr = getattr(self, attr_str).to(target_device) diff --git a/utils/conf.py b/utils/conf.py index ac2e102b..0f2fb681 100644 --- a/utils/conf.py +++ b/utils/conf.py @@ -29,6 +29,24 @@ def warn_once(*msg): print(msg, file=sys.stderr) +def get_alloc_memory_all_devices() -> list[int]: + """ + Returns the memory allocated on all the available devices. + """ + gpu_memory = [] + for i in range(torch.cuda.device_count()): + _ = torch.tensor([1]).to(i) + gpu_memory.append(torch.cuda.memory_allocated(i)) + if all(memory == 0 for memory in gpu_memory): + print("WARNING: some weird GPU memory issue. " + "Using trick from https://discuss.pytorch.org/t/torch-cuda-memory-allocated-returns-0-if-pytorch-no-cuda-memory-caching-1/188796") + for i in range(torch.cuda.device_count()): + torch.zeros(1).to(i) + free_memory, total_memory = torch.cuda.mem_get_info(i) + gpu_memory[i] = total_memory - free_memory + return gpu_memory + + def get_device() -> torch.device: """ Returns the least used GPU device if available else MPS or CPU. @@ -36,11 +54,8 @@ def get_device() -> torch.device: def _get_device(): # get least used gpu by used memory if torch.cuda.is_available() and torch.cuda.device_count() > 0: - gpu_memory = [] - for i in range(torch.cuda.device_count()): - gpu_memory.append(torch.cuda.memory_allocated(i)) + gpu_memory = get_alloc_memory_all_devices() device = torch.device(f'cuda:{np.argmin(gpu_memory)}') - print(f'Using device {device}') return device try: if torch.backends.mps.is_available() and torch.backends.mps.is_built(): diff --git a/utils/deprecated/continual_training.py b/utils/deprecated/continual_training.py index d8e8d88b..50647078 100644 --- a/utils/deprecated/continual_training.py +++ b/utils/deprecated/continual_training.py @@ -12,6 +12,7 @@ from models.utils.continual_model import ContinualModel from utils.loggers import Logger +from utils.stats import track_system_stats from utils.status import progress_bar try: @@ -57,31 +58,35 @@ def train(args: Namespace): backbone = dataset.get_backbone() loss = dataset.get_loss() model = get_model(args, backbone, loss, dataset.get_transform()) - model.net.to(model.device) if not args.disable_log: - logger = Logger(dataset.SETTING, dataset.NAME, model.NAME) - + logger = Logger(args, dataset.SETTING, dataset.NAME, model.NAME) if not args.nowand: assert wandb is not None, "Wandb not installed, please install it or run without wandb" wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args)) args.wandb_url = wandb.run.get_url() - model.net.train() - epoch, i = 0, 0 - while not dataset.train_over: - inputs, labels, not_aug_inputs = dataset.get_train_data() - inputs, labels = inputs.to(model.device), labels.to(model.device) - not_aug_inputs = not_aug_inputs.to(model.device) - loss = model.observe(inputs, labels, not_aug_inputs) - progress_bar(i, dataset.LENGTH // args.batch_size, epoch, 'C', loss) - i += 1 - - if model.NAME == 'joint_gcl': - model.end_task(dataset) - - acc = evaluate(model, dataset) - print('Accuracy:', acc) + model.net.to(model.device) + torch.cuda.empty_cache() + + with track_system_stats(logger) as system_tracker: + epoch, i = 0, 0 + model.net.train() + + while not dataset.train_over: + inputs, labels, not_aug_inputs = dataset.get_train_data() + inputs, labels = inputs.to(model.device), labels.to(model.device) + not_aug_inputs = not_aug_inputs.to(model.device) + loss = model.observe(inputs, labels, not_aug_inputs) + progress_bar(i, dataset.LENGTH // args.batch_size, epoch, 'C', loss) + system_tracker() + i += 1 + + if model.NAME == 'joint_gcl': + model.end_task(dataset) + + acc = evaluate(model, dataset) + print('Accuracy:', acc) if not args.disable_log: logger.log(acc) diff --git a/utils/kornia_utils.py b/utils/kornia_utils.py index 040ecd13..6c85a420 100644 --- a/utils/kornia_utils.py +++ b/utils/kornia_utils.py @@ -6,6 +6,39 @@ from kornia.augmentation.container.params import ParamItem +class KorniaMultiAug(kornia.augmentation.AugmentationSequential): + """ + A custom augmentation class that performs multiple Kornia augmentations. + + Args: + n_augs (int): The number of augmentations to apply. + aug_list (List[kornia.augmentation.AugmentationBase2D]): The list of augmentations to apply. + + Methods: + forward: Overrides the forward method to apply the transformation without gradient computation. + """ + + def __init__(self, n_augs: int, aug_list: List[kornia.augmentation.AugmentationBase2D]): + super().__init__(*aug_list) + self.n_augs = n_augs + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overrides the forward method to apply the transformation without gradient computation. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The transformed tensor. + """ + original_shape = x.shape + x = super().forward(x.repeat(self.n_augs, 1, 1, 1)) + x = x.reshape(self.n_augs, *original_shape) + return x + + class KorniaAugNoGrad(kornia.augmentation.AugmentationSequential): """ A custom augmentation class that applies Kornia augmentations without gradient computation. diff --git a/utils/loggers.py b/utils/loggers.py index 7f0f70a8..d048bd75 100644 --- a/utils/loggers.py +++ b/utils/loggers.py @@ -24,7 +24,7 @@ def log_accs(args, logger, accs, t, setting, epoch=None, prefix="RESULT"): """ Logs the accuracy values and other metrics. - All metrics are prefixed with `RESULT_` to be logged on wandb. + All metrics are prefixed with `prefix` to be logged on wandb. Args: args: The arguments for logging. @@ -35,7 +35,7 @@ def log_accs(args, logger, accs, t, setting, epoch=None, prefix="RESULT"): epoch: The epoch number (optional). prefix: The prefix for the metrics (default="RESULT"). """ - mean_acc = print_mean_accuracy(accs, t + 1, setting, joint=args.joint, epoch=epoch) + mean_acc = print_mean_accuracy(accs, t + 1 if isinstance(t, (float, int)) else t, setting, joint=args.joint, epoch=epoch) if not args.disable_log: logger.log(mean_acc) @@ -72,23 +72,23 @@ def print_mean_accuracy(accs: np.ndarray, task_number: int, prefix = "Joint Accuracy" if epoch is None else f"Joint Accuracy (epoch {epoch})" if setting == 'domain-il' or setting == 'general-continual': mean_acc, _ = mean_acc - print('\n{}: \t [Domain-IL]: {} %'.format(prefix, round(mean_acc, 2), file=sys.stderr)) + print('{}: \t [Domain-IL]: {} %'.format(prefix, round(mean_acc, 2), file=sys.stderr)) print('\tRaw accuracy values: Domain-IL {}'.format(accs[0]), file=sys.stderr) else: mean_acc_class_il, mean_acc_task_il = mean_acc - print('\n{}: \t [Class-IL]: {} % \t [Task-IL]: {} %'.format(prefix, round( + print('{}: \t [Class-IL]: {} % \t [Task-IL]: {} %'.format(prefix, round( mean_acc_class_il, 2), round(mean_acc_task_il, 2)), file=sys.stderr) print('\tRaw accuracy values: Class-IL {} | Task-IL {}'.format(accs[0], accs[1]), file=sys.stderr) else: prefix = "Accuracy" if epoch is None else f"Accuracy (epoch {epoch})" if setting == 'domain-il' or setting == 'general-continual': mean_acc, _ = mean_acc - print('\n{} for {} task(s): [Domain-IL]: {} %'.format(prefix, - task_number, round(mean_acc, 2)), file=sys.stderr) + print('{} for {} task(s): [Domain-IL]: {} %'.format(prefix, + task_number, round(mean_acc, 2)), file=sys.stderr) print('\tRaw accuracy values: Domain-IL {}'.format(accs[0]), file=sys.stderr) else: mean_acc_class_il, mean_acc_task_il = mean_acc - print('\n{} for {} task(s): \t [Class-IL]: {} % \t [Task-IL]: {} %'.format(prefix, task_number, round( + print('{} for {} task(s): \t [Class-IL]: {} % \t [Task-IL]: {} %'.format(prefix, task_number, round( mean_acc_class_il, 2), round(mean_acc_task_il, 2)), file=sys.stderr) print('\tRaw accuracy values: Class-IL {} | Task-IL {}'.format(accs[0], accs[1]), file=sys.stderr) @@ -96,16 +96,18 @@ def print_mean_accuracy(accs: np.ndarray, task_number: int, class Logger: - def __init__(self, setting_str: str, dataset_str: str, + def __init__(self, args, setting_str: str, dataset_str: str, model_str: str) -> None: """ Initializes a Logger object. This will take track and log the accuracy values and other metrics in the default path (`data/results`). Args: + args: The args from the command line. setting_str: The setting of the benchmark. dataset_str: The dataset used. model_str: The model used. """ + self.args = args self.accs = [] self.fullaccs = [] if setting_str == 'class-il': @@ -120,6 +122,8 @@ def __init__(self, setting_str: str, dataset_str: str, self.bwt_mask_classes = None self.forgetting = None self.forgetting_mask_classes = None + self.cpu_res = [] + self.gpu_res = [] def dump(self): """ @@ -248,6 +252,23 @@ def log_fullacc(self, accs): self.fullaccs.append(acc_class_il) self.fullaccs_mask_classes.append(acc_task_il) + def log_system_stats(self, cpu_res, gpu_res): + """ + Logs the system stats. + Supported only if the `psutil` and `torch` libraries are installed. + + Args: + cpu_res: the CPU memory usage + gpu_res: the GPU memory usage + """ + if cpu_res is not None: + self.cpu_res.append(cpu_res) + if gpu_res is not None: + self.gpu_res.append(gpu_res) + + if not self.args.nowand: + wandb.log({'CPU_memory_usage': cpu_res, **{f'GPU_{i}_memory_usage': r for i, r in gpu_res.items()}}) + def write(self, args: Dict[str, Any]) -> None: """ Writes out the logged value along with its arguments in the default path (`data/results`). @@ -264,6 +285,9 @@ def write(self, args: Dict[str, Any]) -> None: for j, acc in enumerate(fa): wrargs['accuracy_' + str(j + 1) + '_task' + str(i + 1)] = acc + wrargs['cpu_memory_usage'] = self.cpu_res + wrargs['gpu_memory_usage'] = self.gpu_res + wrargs['forward_transfer'] = self.fwt wrargs['backward_transfer'] = self.bwt wrargs['forgetting'] = self.forgetting diff --git a/utils/main.py b/utils/main.py index 5d32469a..33415d7b 100644 --- a/utils/main.py +++ b/utils/main.py @@ -35,7 +35,7 @@ from utils import create_if_not_exists, custom_str_underscore from utils.args import add_management_args, add_experiment_args -from utils.conf import base_path +from utils.conf import base_path, get_device from utils.distributed import make_dp from utils.best_args import best_args from utils.conf import set_random_seed @@ -57,7 +57,8 @@ def parse_args(): args (argparse.Namespace): Parsed command line arguments. """ from models import get_all_models, get_model_class - from datasets import get_dataset_names, get_dataset_class + from datasets import get_dataset_names, get_dataset + # from datasets.utils import update_default_args parser = ArgumentParser(description='mammoth', allow_abbrev=False, add_help=False) parser.add_argument('--model', type=custom_str_underscore, help='Model name.', choices=list(get_all_models().keys())) @@ -105,14 +106,7 @@ def parse_args(): add_experiment_args(parser) args = parser.parse_args() - tmp_dset_class = get_dataset_class(args) - n_epochs = tmp_dset_class.get_epochs() - if args.n_epochs is None: - args.n_epochs = n_epochs - else: - if args.n_epochs != n_epochs: - print('Warning: n_epochs set to {} instead of {}.'.format(args.n_epochs, n_epochs), file=sys.stderr) - + get_dataset(args).update_default_args() args.model = models_dict[args.model] if args.lr_scheduler is not None: @@ -138,6 +132,10 @@ def parse_args(): assert 0 < args.label_perc <= 1, "label_perc must be in (0, 1]" + if args.validation is not None: + print(f"INFO: Using {args.validation}% of the training set as validation set.", file=sys.stderr) + print(f"INFO: Validation will be computed with mode `{args.validation_mode}`.", file=sys.stderr) + return args @@ -150,11 +148,20 @@ def main(args=None): if args is None: args = parse_args() + device = get_device() + args.device = device + # set base path base_path(args.base_path) - os.putenv("MKL_SERVICE_FORCE_INTEL", "1") - os.putenv("NPY_MKL_FORCE_INTEL", "1") + if args.code_optimization != 0: + torch.set_float32_matmul_precision('high' if args.code_optimization == 1 else 'medium') + print("INFO: code_optimization is set to", args.code_optimization, file=sys.stderr) + print(f"Using {torch.get_float32_matmul_precision()} precision for matmul.", file=sys.stderr) + + if args.code_optimization == 2: + if not torch.cuda.is_bf16_supported(): + raise NotImplementedError('BF16 is not supported on this machine.') # Add uuid, timestamp and hostname for logging args.conf_jobnum = str(uuid.uuid4()) @@ -162,8 +169,11 @@ def main(args=None): args.conf_host = socket.gethostname() dataset = get_dataset(args) - if args.n_epochs is None and isinstance(dataset, ContinualDataset): + if args.fitting_mode == 'epochs' and args.n_epochs is None and isinstance(dataset, ContinualDataset): args.n_epochs = dataset.get_epochs() + elif args.fitting_mode == 'iters' and args.n_iters is None and isinstance(dataset, ContinualDataset): + args.n_iters = dataset.get_iters() + if args.batch_size is None: args.batch_size = dataset.get_batch_size() if hasattr(importlib.import_module('models.' + args.model), 'Buffer') and (not hasattr(args, 'minibatch_size') or args.minibatch_size is None): @@ -171,9 +181,30 @@ def main(args=None): else: args.minibatch_size = args.batch_size + if args.validation: + if args.validation_mode == 'current': + assert dataset.SETTING in ['class-il', 'task-il'], "`current` validation modes is only supported for class-il and task-il settings (requires a task division)." + backbone = dataset.get_backbone() + if args.code_optimization == 3: + # check if the model is compatible with torch.compile + # from https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html + if torch.cuda.get_device_capability()[0] >= 7 and os.name != 'nt': + print("================ Compiling model with torch.compile ================") + print("WARNING: `torch.compile` may break your code if you change the model after the first run!") + print("This includes adding classifiers for new tasks, changing the backbone, etc.") + print("ALSO: some models CHANGE the backbone during initialization. Remember to call `torch.compile` again after that.") + print("====================================================================") + backbone = torch.compile(backbone) + else: + if torch.cuda.get_device_capability()[0] < 7: + raise NotImplementedError('torch.compile is not supported on this machine.') + else: + raise Exception(f"torch.compile is not supported on Windows. Check https://github.com/pytorch/pytorch/issues/90768 for updates.") + loss = dataset.get_loss() model = get_model(args, backbone, loss, dataset.get_transform()) + # model = torch.compile(model) if args.distributed == 'dp': if args.batch_size < torch.cuda.device_count(): diff --git a/utils/ring_buffer.py b/utils/ring_buffer.py index 310dc267..baf2b400 100644 --- a/utils/ring_buffer.py +++ b/utils/ring_buffer.py @@ -13,6 +13,8 @@ import torch from torchvision import transforms +from utils.augmentations import apply_transform + def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int: return num_seen_examples % buffer_portion_size + task * buffer_portion_size @@ -94,11 +96,11 @@ def get_data(self, size: int, transform: transforms = None, device=None) -> Tupl if size > populated_portion_length: size = populated_portion_length - choice = torch.from_numpy(np.random.choice(populated_portion_length, size=size, replace=False)).to(self.device, dtype=torch.long) + choice = torch.from_numpy(np.random.choice(populated_portion_length, size=size, replace=False)) if transform is None: def transform(x): return x - ret_tuple = (torch.stack([transform(ee) - for ee in self.examples[choice].cpu()]).to(target_device),) + + ret_tuple = (apply_transform(self.examples[choice], transform=transform).to(target_device),) for attr_str in self.attributes[1:]: if hasattr(self, attr_str): attr = getattr(self, attr_str) @@ -129,8 +131,7 @@ def get_all_data(self, transform: transforms = None, device=None) -> Tuple: target_device = self.device if device is None else device if transform is None: def transform(x): return x - ret_tuple = (torch.stack([transform(ee) - for ee in self.examples.cpu()]).to(target_device),) + ret_tuple = (apply_transform(self.examples[choice], transform=transform).to(target_device),) for attr_str in self.attributes[1:]: if hasattr(self, attr_str): attr = getattr(self, attr_str) diff --git a/utils/stats.py b/utils/stats.py new file mode 100644 index 00000000..d3d36a68 --- /dev/null +++ b/utils/stats.py @@ -0,0 +1,163 @@ +try: + from resource import getrusage, RUSAGE_CHILDREN, RUSAGE_SELF + + def get_memory_mb(): + """ + Get the memory usage of the current process and its children. + + Returns: + dict: A dictionary containing the memory usage of the current process and its children. + + The dictionary has the following + keys: + - self: The memory usage of the current process. + - children: The memory usage of the children of the current process. + - total: The total memory usage of the current process and its children. + """ + res = { + "self": getrusage(RUSAGE_SELF).ru_maxrss / 1024, + "children": getrusage(RUSAGE_CHILDREN).ru_maxrss / 1024, + "total": getrusage(RUSAGE_SELF).ru_maxrss / 1024 + getrusage(RUSAGE_CHILDREN).ru_maxrss / 1024 + } + return res +except BaseException: + get_memory_mb = None + +try: + import torch + + if torch.cuda.is_available(): + from utils.conf import get_alloc_memory_all_devices + + def get_memory_gpu_mb(): + """ + Get the memory usage of all GPUs in MB. + """ + + return [d / 1024 for d in get_alloc_memory_all_devices()] + else: + get_memory_gpu_mb = None +except BaseException: + get_memory_gpu_mb = None + +from utils.loggers import Logger + + +class track_system_stats: + """ + A context manager that tracks the memory usage of the system. + Tracks both CPU and GPU memory usage if available. + + Usage: + with track_system_stats() as t: + for i in range(100): + ... # Do something + t() + + cpu_res, gpu_res = t.cpu_res, t.gpu_res + + Args: + logger (Logger): external logger. + disabled (bool): If True, the context manager will not track the memory usage. + """ + + def get_stats(self): + """ + Get the memory usage of the system. + + Returns: + tuple: (cpu_res, gpu_res) where cpu_res is the memory usage of the CPU and gpu_res is the memory usage of the GPU. + """ + cpu_res = None + if get_memory_mb is not None: + cpu_res = get_memory_mb()['total'] + + gpu_res = None + if get_memory_gpu_mb is not None: + gpu_res = get_memory_gpu_mb() + + return cpu_res, gpu_res + + def __init__(self, logger: Logger = None, disabled=False): + self.logger = logger + self.disabled = disabled + self._it = 0 + + def __enter__(self): + if self.disabled: + return self + self.initial_cpu_res, self.initial_gpu_res = self.get_stats() + self.initial_gpu_res = {g: g_res for g, g_res in enumerate(self.initial_gpu_res)} + + self.avg_gpu_res = self.initial_gpu_res + self.avg_cpu_res = self.initial_cpu_res + + self.max_cpu_res = self.initial_cpu_res + self.max_gpu_res = self.initial_gpu_res + + if self.initial_cpu_res is None and self.initial_gpu_res is None: + self.disabled = True + + if self.logger is not None: + self.logger.log_system_stats(self.initial_cpu_res, self.initial_gpu_res) + + return self + + def __call__(self): + if self.disabled: + return + + cpu_res, gpu_res = self.get_stats() + self.update_stats(cpu_res, gpu_res) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.disabled: + return + + cpu_res, gpu_res = self.get_stats() + self.update_stats(cpu_res, gpu_res) + + def update_stats(self, cpu_res, gpu_res): + """ + Update the memory usage statistics. + + Args: + cpu_res (float): The memory usage of the CPU. + gpu_res (list): The memory usage of the GPUs. + """ + if self.disabled: + return + + self._it += 1 + + alpha = 1 / self._it + self.avg_cpu_res = self.avg_cpu_res + alpha * (cpu_res - self.avg_cpu_res) + self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in enumerate(gpu_res)} + + self.max_cpu_res = max(self.max_cpu_res, cpu_res) + self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)} + + if self.logger is not None: + self.logger.log_system_stats(cpu_res, gpu_res) + + def print_stats(self): + """ + Print the memory usage statistics. + """ + + cpu_res, gpu_res = self.get_stats() + + # Print initial, average, final, and max memory usage + print("System stats:") + if cpu_res is not None: + print(f"\tInitial CPU memory usage: {self.initial_cpu_res:.2f} MB", flush=True) + print(f"\tAverage CPU memory usage: {self.avg_cpu_res:.2f} MB", flush=True) + print(f"\tFinal CPU memory usage: {cpu_res:.2f} MB", flush=True) + print(f"\tMax CPU memory usage: {self.max_cpu_res:.2f} MB", flush=True) + + if gpu_res is not None: + for gpu_id, g_res in enumerate(gpu_res): + print(f"\tInitial GPU {gpu_id} memory usage: {self.initial_gpu_res[gpu_id]:.2f} MB", flush=True) + print(f"\tAverage GPU {gpu_id} memory usage: {self.avg_gpu_res[gpu_id]:.2f} MB", flush=True) + print(f"\tFinal GPU {gpu_id} memory usage: {g_res:.2f} MB", flush=True) + print(f"\tMax GPU {gpu_id} memory usage: {self.max_gpu_res[gpu_id]:.2f} MB", flush=True) diff --git a/utils/status.py b/utils/status.py index 4c24c15c..a799895c 100644 --- a/utils/status.py +++ b/utils/status.py @@ -7,21 +7,52 @@ from datetime import datetime from time import time from typing import Union +import shutil + + +def padded_print(string: str, max_width: int, **kwargs) -> None: + """ + Prints a string with blank spaces to reach the max_width. + + Args: + string: the string to print + max_width: the maximum width of the string + """ + pad_len = max(0, max_width - len(string)) + print(string + ' ' * pad_len, **kwargs) class ProgressBar: - def __init__(self, joint=False, verbose=True): + def __init__(self, joint=False, verbose=True, update_every=1): """ Initializes a ProgressBar object. Args: joint: a boolean indicating whether the progress bar is for a joint task verbose: a boolean indicating whether to display the progress bar + update_every: the number of iterations after which the progress bar is updated """ self.joint = joint - self.old_time = 0 - self.running_sum = 0 + self.update_every = update_every self.verbose = verbose + self.old_time = None + + self.reset() + + assert self.update_every > 0 + + def reset(self) -> None: + """ + Resets the progress bar. + """ + if self.old_time is not None: + max_width = shutil.get_terminal_size((80, 20)).columns + padded_print(f'\n\t- Took: {round(self.running_sum, 2)} s', max_width=max_width, file=sys.stderr, flush=True) + + self.old_time = time() + self.running_sum = 0 + self.current_task_iter = 0 + self.last_task = 0 def prog(self, i: int, max_iter: int, epoch: Union[int, str], task_number: int, loss: float) -> None: @@ -35,47 +66,59 @@ def prog(self, i: int, max_iter: int, epoch: Union[int, str], task_number: the task index loss: the current value of the loss function """ + max_width = shutil.get_terminal_size((80, 20)).columns if not self.verbose: if i == 0: if self.joint: - print('[ {} ] Joint | epoch {}\n'.format( + padded_print('[ {} ] Joint | epoch {}\n'.format( datetime.now().strftime("%m-%d | %H:%M"), epoch - ), file=sys.stderr, end='', flush=True) + ), max_width=max_width, file=sys.stderr, end='', flush=True) else: - print('[ {} ] Task {} | epoch {}\n'.format( + padded_print('[ {} ] Task {} | epoch {}\n'.format( datetime.now().strftime("%m-%d | %H:%M"), task_number + 1 if isinstance(task_number, int) else task_number, epoch - ), file=sys.stderr, end='', flush=True) + ), max_width=max_width, file=sys.stderr, end='', flush=True) else: return - if i == 0: - self.old_time = time() - self.running_sum = 0 - else: - self.running_sum = self.running_sum + (time() - self.old_time) + 1e-8 - self.old_time = time() - if i: # not (i + 1) % 10 or (i + 1) == max_iter: + + timediff = time() - self.old_time + self.running_sum = self.running_sum + timediff + 1e-8 + + # Print the progress bar every update_every iterations + if (i and i % self.update_every == 0) or (max_iter is not None and i == max_iter - 1): progress = min(float((i + 1) / max_iter), 1) if max_iter else 0 progress_bar = ('█' * int(50 * progress)) + ('┈' * (50 - int(50 * progress))) if max_iter else '~N/A~' if self.joint: - print('\r[ {} ] Joint | epoch {}: |{}| {} ep/h | loss: {} |'.format( + padded_print('\r[ {} ] Joint | epoch {} | iter {}: |{}| {} ep/h | loss: {} | Time: {} ms/it'.format( datetime.now().strftime("%m-%d | %H:%M"), epoch, + self.current_task_iter + 1, progress_bar, round(3600 / (self.running_sum / i * max_iter), 2) if max_iter else 'N/A', - round(loss, 8) - ), file=sys.stderr, end='', flush=True) + round(loss, 8), + round(1000 * timediff / self.update_every, 2) + ), max_width=max_width, file=sys.stderr, end='', flush=True) else: - print('\r[ {} ] Task {} | epoch {}: |{}| {} ep/h | loss: {} |'.format( + padded_print('\r[ {} ] Task {} | epoch {} | iter {}: |{}| {} ep/h | loss: {} | Time: {} ms/it'.format( datetime.now().strftime("%m-%d | %H:%M"), task_number + 1 if isinstance(task_number, int) else task_number, epoch, + self.current_task_iter + 1, progress_bar, round(3600 / (self.running_sum / i * max_iter), 2) if max_iter else 'N/A', - round(loss, 8) - ), file=sys.stderr, end='', flush=True) + round(loss, 8), + round(1000 * timediff / self.update_every, 2) + ), max_width=max_width, file=sys.stderr, end='', flush=True) + + self.current_task_iter += 1 + + # def __del__(self): + # max_width = shutil.get_terminal_size((80, 20)).columns + # # if self.verbose: + # # print('\n', file=sys.stderr, flush=True) + # padded_print('\tLast task took: {} s'.format(round(self.running_sum, 2)), max_width=max_width, file=sys.stderr, flush=True) def progress_bar(i: int, max_iter: int, epoch: Union[int, str], diff --git a/utils/training.py b/utils/training.py index 219c8d96..46815172 100644 --- a/utils/training.py +++ b/utils/training.py @@ -7,7 +7,7 @@ import math import sys from argparse import Namespace -from typing import Tuple +from typing import Iterable, Tuple import torch from datasets import get_dataset @@ -18,7 +18,9 @@ from utils import random_id from utils.checkpoints import mammoth_load_checkpoint from utils.loggers import * +from utils.stats import track_system_stats from utils.status import ProgressBar +import time try: import wandb @@ -43,7 +45,7 @@ def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> No @torch.no_grad() -def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tuple[list, list]: +def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False, return_loss=False) -> Tuple[list, list]: """ Evaluates the accuracy of the model for each past task. @@ -52,14 +54,18 @@ def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tu Args: model: the model to be evaluated dataset: the continual dataset at hand + last: a boolean indicating whether to evaluate only the last task + return_loss: a boolean indicating whether to return the loss in addition to the accuracy Returns: - a tuple of lists, containing the class-il and task-il accuracy for each task + a tuple of lists, containing the class-il and task-il accuracy for each task. If return_loss is True, the loss is also returned as a third element. """ status = model.net.training model.net.eval() accs, accs_mask_classes = [], [] n_classes = dataset.get_offsets()[1] + loss_fn = dataset.get_loss() + avg_loss = 0 for k, test_loader in enumerate(dataset.test_loaders): if last and k < len(dataset.test_loaders) - 1: continue @@ -80,6 +86,10 @@ def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tu else: outputs = model(inputs) + if return_loss: + loss = loss_fn(outputs, labels) + avg_loss += loss.item() + _, pred = torch.max(outputs[:, :n_classes].data, 1) correct += torch.sum(pred == labels).item() total += labels.shape[0] @@ -95,6 +105,8 @@ def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tu accs_mask_classes.append(correct_mask_classes / total * 100) model.net.train(status) + if return_loss: + return accs, accs_mask_classes, avg_loss / total return accs, accs_mask_classes @@ -114,6 +126,71 @@ def initialize_wandb(args: Namespace) -> None: args.wandb_url = wandb.run.get_url() +def train_single_epoch(model: ContinualModel, + train_loader: Iterable, + progress_bar: ProgressBar, + args: Namespace, + epoch: int, + current_task: int, + system_tracker=None, + data_len=None, + scheduler=None) -> int: + """ + Trains the model for a single epoch. + + Args: + model: the model to be trained + train_loader: the data loader for the training set + progress_bar: the progress bar for the current epoch + args: the arguments from the command line + epoch: the current epoch + current_task: the current task index + system_tracker: the system tracker to monitor the system stats + data_len: the length of the training data loader. If None, the progress bar will not show the training percentage + scheduler: the scheduler for the current epoch + + Returns: + the number of iterations performed in the current epoch + """ + train_iter = iter(train_loader) + + i = 0 + while True: + try: + data = next(train_iter) + except StopIteration: + break + if args.debug_mode and i > model.get_debug_iters(): + break + if args.fitting_mode == 'iters' and progress_bar.current_task_iter >= model.args.n_iters: + break + + if hasattr(train_loader.dataset, 'logits'): + inputs, labels, not_aug_inputs, logits = data + inputs = inputs.to(model.device) + labels = labels.to(model.device, dtype=torch.long) + not_aug_inputs = not_aug_inputs.to(model.device) + logits = logits.to(model.device) + loss = model.meta_observe(inputs, labels, not_aug_inputs, logits, epoch=epoch) + else: + inputs, labels, not_aug_inputs = data + inputs, labels = inputs.to(model.device), labels.to(model.device, dtype=torch.long) + not_aug_inputs = not_aug_inputs.to(model.device) + loss = model.meta_observe(inputs, labels, not_aug_inputs, epoch=epoch) + assert not math.isnan(loss) + + if args.code_optimization == 0: + torch.cuda.synchronize() + progress_bar.prog(i, data_len, epoch, current_task, loss) + system_tracker() + i += 1 + + if scheduler is not None: + scheduler.step() + + return i + + def train(model: ContinualModel, dataset: ContinualDataset, args: Namespace) -> None: """ @@ -129,131 +206,156 @@ def train(model: ContinualModel, dataset: ContinualDataset, if not args.nowand: initialize_wandb(args) + if not args.disable_log: + logger = Logger(args, dataset.SETTING, dataset.NAME, model.NAME) + model.net.to(model.device) - results, results_mask_classes = [], [] + torch.cuda.empty_cache() - if not args.disable_log: - logger = Logger(dataset.SETTING, dataset.NAME, model.NAME) + with track_system_stats(logger) as system_tracker: + results, results_mask_classes = [], [] - if args.start_from is not None: - for i in range(args.start_from): - train_loader, _ = dataset.get_data_loaders() - model.meta_begin_task(dataset) - model.meta_end_task(dataset) + if args.start_from is not None: + for i in range(args.start_from): + train_loader, _ = dataset.get_data_loaders() + model.meta_begin_task(dataset) + model.meta_end_task(dataset) + + if args.loadcheck is not None: + model, past_res = mammoth_load_checkpoint(args, model) - if args.loadcheck is not None: - model, past_res = mammoth_load_checkpoint(args, model) + if not args.disable_log and past_res is not None: + (results, results_mask_classes, csvdump) = past_res + logger.load(csvdump) - if not args.disable_log and past_res is not None: - (results, results_mask_classes, csvdump) = past_res - logger.load(csvdump) + print('Checkpoint Loaded!') - print('Checkpoint Loaded!') + progress_bar = ProgressBar(joint=args.joint, verbose=not args.non_verbose) - progress_bar = ProgressBar(joint=args.joint, verbose=not args.non_verbose) + if args.enable_other_metrics: + dataset_copy = get_dataset(args) + for t in range(dataset.N_TASKS): + model.net.train() + _, _ = dataset_copy.get_data_loaders() + if model.NAME != 'icarl' and model.NAME != 'pnn': + random_results_class, random_results_task = evaluate(model, dataset_copy) - if args.enable_other_metrics: - dataset_copy = get_dataset(args) - for t in range(dataset.N_TASKS): + print(file=sys.stderr) + start_task = 0 if args.start_from is None else args.start_from + end_task = dataset.N_TASKS if args.stop_after is None else args.stop_after + + torch.cuda.empty_cache() + for t in range(start_task, end_task): model.net.train() - _, _ = dataset_copy.get_data_loaders() - if model.NAME != 'icarl' and model.NAME != 'pnn': - random_results_class, random_results_task = evaluate(model, dataset_copy) + train_loader, test_loader = dataset.get_data_loaders() + model.meta_begin_task(dataset) - print(file=sys.stderr) - start_task = 0 if args.start_from is None else args.start_from - end_task = dataset.N_TASKS if args.stop_after is None else args.stop_after + if not args.inference_only: + if t and args.enable_other_metrics: + accs = evaluate(model, dataset, last=True) + results[t - 1] = results[t - 1] + accs[0] + if dataset.SETTING == 'class-il': + results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1] - torch.cuda.empty_cache() - for t in range(start_task, end_task): - model.net.train() - train_loader, test_loader = dataset.get_data_loaders() - model.meta_begin_task(dataset) - - if not args.inference_only: - if t and args.enable_other_metrics: - accs = evaluate(model, dataset, last=True) - results[t - 1] = results[t - 1] + accs[0] - if dataset.SETTING == 'class-il': - results_mask_classes[t - 1] = results_mask_classes[t - 1] + accs[1] - - scheduler = dataset.get_scheduler(model, args) if not hasattr(model, 'scheduler') else model.scheduler - for epoch in range(model.args.n_epochs): - train_iter = iter(train_loader) - data_len = None - if not isinstance(dataset, GCLDataset): - data_len = len(train_loader) - i = 0 + scheduler = dataset.get_scheduler(model, args) if not hasattr(model, 'scheduler') else model.scheduler + + epoch = 0 + best_ea_metric = None + best_ea_model = None + cur_stopping_patience = args.early_stopping_patience while True: - try: - data = next(train_iter) - except StopIteration: + data_len = None + if not isinstance(dataset, GCLDataset): + data_len = len(train_loader) + + train_single_epoch(model, train_loader, progress_bar, args, current_task=t, epoch=epoch, + system_tracker=system_tracker, data_len=data_len, scheduler=scheduler) + + epoch += 1 + if args.fitting_mode == 'epochs' and epoch >= model.args.n_epochs: break - if args.debug_mode and i > model.get_debug_iters(): + elif args.fitting_mode == 'iters' and progress_bar.current_task_iter >= model.args.n_iters: break - if hasattr(dataset.train_loader.dataset, 'logits'): - inputs, labels, not_aug_inputs, logits = data - inputs = inputs.to(model.device) - labels = labels.to(model.device, dtype=torch.long) - not_aug_inputs = not_aug_inputs.to(model.device) - logits = logits.to(model.device) - loss = model.meta_observe(inputs, labels, not_aug_inputs, logits, epoch=epoch) - else: - inputs, labels, not_aug_inputs = data - inputs, labels = inputs.to(model.device), labels.to(model.device, dtype=torch.long) - not_aug_inputs = not_aug_inputs.to(model.device) - loss = model.meta_observe(inputs, labels, not_aug_inputs, epoch=epoch) - assert not math.isnan(loss) - progress_bar.prog(i, data_len, epoch, t, loss) - i += 1 - - if scheduler is not None: - scheduler.step() - - if args.eval_epochs is not None and epoch % args.eval_epochs == 0 and epoch < model.args.n_epochs - 1: - epoch_accs = evaluate(model, dataset) - - log_accs(args, logger, epoch_accs, t, dataset.SETTING, epoch=epoch) - - model.meta_end_task(dataset) - - accs = evaluate(model, dataset) - results.append(accs[0]) - results_mask_classes.append(accs[1]) - - log_accs(args, logger, accs, t, dataset.SETTING) - - if args.savecheck: - save_obj = { - 'model': model.state_dict(), - 'args': args, - 'results': [results, results_mask_classes, logger.dump()], - 'optimizer': model.opt.state_dict() if hasattr(model, 'opt') else None, - 'scheduler': scheduler.state_dict() if scheduler is not None else None, - } - if 'buffer_size' in model.args: - save_obj['buffer'] = deepcopy(model.buffer).to('cpu') - - # Saving model checkpoint - checkpoint_name = f'checkpoints/{args.ckpt_name}_joint.pt' if args.joint else f'checkpoints/{args.ckpt_name}_{t}.pt' - torch.save(save_obj, checkpoint_name) - - if args.validation: - del dataset - args.validation = None - - final_dataset = get_dataset(args) - for _ in range(final_dataset.N_TASKS): - final_dataset.get_data_loaders() - accs = evaluate(model, final_dataset) - log_accs(args, logger, accs, t, final_dataset.SETTING, prefix="FINAL") - - if not args.disable_log and args.enable_other_metrics: - logger.add_bwt(results, results_mask_classes) - logger.add_forgetting(results, results_mask_classes) - if model.NAME != 'icarl' and model.NAME != 'pnn': - logger.add_fwt(results, random_results_class, - results_mask_classes, random_results_task) + elif args.fitting_mode == 'early_stopping' and epoch % args.early_stopping_freq == 0 and epoch > 0: + epoch_accs, _, epoch_loss = evaluate(model, dataset, return_loss=True, last=True) + + if args.early_stopping_metric == 'accuracy': + ea_metric = np.mean(epoch_accs) # Higher accuracy is better + elif args.early_stopping_metric == 'loss': + ea_metric = -epoch_loss # Lower loss is better + else: + raise ValueError(f'Unknown early stopping metric {args.early_stopping_metric}') + + # Higher accuracy is better + if best_ea_metric is not None and ea_metric - best_ea_metric < args.early_stopping_epsilon: + cur_stopping_patience -= args.early_stopping_freq + if cur_stopping_patience <= 0: + print(f"\nEarly stopping at epoch {epoch} with metric {abs(ea_metric)}", file=sys.stderr) + model.load_state_dict({k: v.to(model.device) for k, v in best_ea_model.items()}) + break + print(f"\nNo improvement at epoch {epoch} (best {abs(best_ea_metric)} | current {abs(ea_metric)}). " + f"Waiting for {cur_stopping_patience} epochs to stop.", file=sys.stderr) + else: + print(f"\nFound better model with metric {abs(ea_metric)} at epoch {epoch}. " + f"Previous value was {abs(best_ea_metric) if best_ea_metric is not None else 'None'}", file=sys.stderr) + best_ea_metric = ea_metric + best_ea_model = deepcopy({k: v.cpu() for k, v in model.state_dict().items()}) + cur_stopping_patience = args.early_stopping_patience + + if args.eval_epochs is not None and (epoch > 0 or args.eval_epochs == 1) and epoch % args.eval_epochs == 0 and epoch < model.args.n_epochs: + epoch_accs = evaluate(model, dataset) + + log_accs(args, logger, epoch_accs, t, dataset.SETTING, epoch=epoch) + + progress_bar.reset() + + model.meta_end_task(dataset) + + accs = evaluate(model, dataset) + results.append(accs[0]) + results_mask_classes.append(accs[1]) + + log_accs(args, logger, accs, t, dataset.SETTING) + + if args.savecheck: + save_obj = { + 'model': model.state_dict(), + 'args': args, + 'results': [results, results_mask_classes, logger.dump()], + 'optimizer': model.opt.state_dict() if hasattr(model, 'opt') else None, + 'scheduler': scheduler.state_dict() if scheduler is not None else None, + } + if 'buffer_size' in model.args: + save_obj['buffer'] = deepcopy(model.buffer).to('cpu') + + # Saving model checkpoint + checkpoint_name = f'checkpoints/{args.ckpt_name}_joint.pt' if args.joint else f'checkpoints/{args.ckpt_name}_{t}.pt' + torch.save(save_obj, checkpoint_name) + + del progress_bar + + if args.validation: + # Final evaluation on the real test set + print("Starting final evaluation on the real test set...", file=sys.stderr) + del dataset + args.validation = None + args.validation_mode = 'current' + + final_dataset = get_dataset(args) + for _ in range(final_dataset.N_TASKS): + final_dataset.get_data_loaders() + accs = evaluate(model, final_dataset) + + log_accs(args, logger, accs, 'final', final_dataset.SETTING, prefix="FINAL") + + if not args.disable_log and args.enable_other_metrics: + logger.add_bwt(results, results_mask_classes) + logger.add_forgetting(results, results_mask_classes) + if model.NAME != 'icarl' and model.NAME != 'pnn': + logger.add_fwt(results, random_results_class, + results_mask_classes, random_results_task) + + system_tracker.print_stats() if not args.disable_log: logger.write(vars(args))