
Commit

Merge pull request #40 from aimagelab/dev
New ViT backbone, validation, and code optimizations
loribonna authored Jun 16, 2024
2 parents 24a2f6a + 37e1363 commit 88c06ec
Showing 80 changed files with 3,385 additions and 2,596 deletions.
3 changes: 3 additions & 0 deletions NOTICE.md
@@ -0,0 +1,3 @@
This project is licensed under the MIT License. However, some files in this repository are under the Apache 2.0 License as noted below:

- backbone/vit.py (Apache 2.0 License)
5 changes: 3 additions & 2 deletions backbone/EfficientNet.py
@@ -155,7 +155,7 @@ def round_filters(filters, global_params):
multiplier = global_params.width_coefficient
if not multiplier:
return filters

divisor = global_params.depth_divisor
min_depth = global_params.min_depth
filters *= multiplier
@@ -244,6 +244,7 @@ def get_same_padding_conv2d(image_size=None):
else:
return partial(Conv2dStaticSamePadding, image_size=image_size)


GlobalParams = collections.namedtuple('GlobalParams', [
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format',
'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor',
@@ -848,7 +849,7 @@ def forward(self, inputs, returnt='out'):
x = self.classifier(x)
if returnt == 'out':
return x
elif returnt == 'all':
elif returnt == 'full':
return (x, feats)

raise NotImplementedError("Unknown return type")
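
The `returnt='all'` option is renamed to `returnt='full'` across the backbones in this commit. A minimal sketch of the updated call, assuming `model` is one of these backbones (e.g. the EfficientNet above) and `images` is a batch of the expected input shape; the names here are illustrative, not part of the diff:

    logits = model(images)                         # default: returnt='out', classifier output only
    logits, feats = model(images, returnt='full')  # formerly returnt='all': output plus pre-classifier features
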
2 changes: 1 addition & 1 deletion backbone/MNISTMLP.py
@@ -68,7 +68,7 @@ def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:

if returnt == 'out':
return out
elif returnt == 'all':
elif returnt == 'full':
return (out, feats)

raise NotImplementedError("Unknown return type")
2 changes: 1 addition & 1 deletion backbone/ResNet18_PNN.py
@@ -10,7 +10,7 @@
import torch.nn.functional as F
from torch.nn.functional import avg_pool2d, relu

from backbone.ResNet18 import BasicBlock, ResNet, conv3x3
from backbone.ResNetBlock import BasicBlock, ResNet, conv3x3
from backbone.utils.modules import AlphaModule, ListModule


18 changes: 17 additions & 1 deletion backbone/ResNet18.py → backbone/ResNetBlock.py
@@ -110,6 +110,8 @@ def __init__(self, block: BasicBlock, num_blocks: List[int],
self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
self.classifier = nn.Linear(nf * 8 * block.expansion, num_classes)

self.feature_dim = nf * 8 * block.expansion

def to(self, device, **kwargs):
self.device = device
return super().to(device, **kwargs)
@@ -185,7 +187,7 @@ def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
out_4 if not self.return_prerelu else self.layer4[-1].prerelu
]

raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'all'] but got {}".format(returnt))
raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'full'] but got {}".format(returnt))


def resnet18(nclasses: int, nf: int = 64) -> ResNet:
@@ -200,3 +202,17 @@ def resnet18(nclasses: int, nf: int = 64) -> ResNet:
ResNet network
"""
return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf)


def resnet34(nclasses: int, nf: int = 64) -> ResNet:
"""
Instantiates a ResNet34 network.
Args:
nclasses: number of output classes
nf: number of filters
Returns:
ResNet network
"""
return ResNet(BasicBlock, [3, 4, 6, 3], nclasses, nf)
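
Alongside the rename of backbone/ResNet18.py to backbone/ResNetBlock.py, the commit adds a resnet34 factory and a feature_dim attribute. A short usage sketch, assuming CIFAR-sized 3x32x32 inputs and the default nf=64 (the variable names are illustrative assumptions):

    import torch
    from backbone.ResNetBlock import resnet18, resnet34

    net18 = resnet18(nclasses=10)              # BasicBlock layout [2, 2, 2, 2]
    net34 = resnet34(nclasses=10)              # new in this commit: layout [3, 4, 6, 3]
    print(net34.feature_dim)                   # nf * 8 * block.expansion = 512 with defaults
    logits = net34(torch.randn(2, 3, 32, 32))  # default returnt='out'
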
4 changes: 3 additions & 1 deletion backbone/ResNet50.py → backbone/ResNetBottleneck.py
@@ -146,6 +146,8 @@ def __init__(
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(512 * block.expansion, num_classes)

self.feature_dim = 512 * block.expansion

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
@@ -242,7 +244,7 @@ def forward(self, x: Tensor, returnt="out") -> Tensor:
elif returnt == 'both':
return (out, feature)

raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'all'] but got {}".format(returnt))
raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'full'] but got {}".format(returnt))

def set_grad_filter(self, filter_s: str, enable: bool) -> None:
negative_mode = filter_s[0] == '~'
167 changes: 167 additions & 0 deletions backbone/utils/layers.py
@@ -0,0 +1,167 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer():
def __init__(
self,
lora_dropout: float,
):
# Optional dropout
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x


class LoRALinear(nn.Linear, LoRALayer):

def __init__(
self,
in_features: int,
out_features: int,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, lora_dropout=lora_dropout)

self.fan_in_fan_out = fan_in_fan_out
self.weight.requires_grad = False
self.reset_parameters()

if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)

def reset_parameters(self):
nn.Linear.reset_parameters(self)

def forward(self, x: torch.Tensor, AB: dict = None):

def T(w):
return w.transpose(1, 2) if self.fan_in_fan_out else w

result = F.linear(x, T(self.weight), bias=self.bias)

if AB is not None:
A = None
if isinstance(AB, dict):
B = AB['B']
A = AB.get('A')
else:
B = AB
if A is not None:
return result + (B @ (A @ x.transpose(1, 2).unsqueeze(1))).sum(1).transpose(1, 2)
return result + (B @ x.transpose(1, 2).unsqueeze(1)).sum(1).transpose(1, 2)

return result


class ClipLinear(nn.Linear, LoRALayer):

def __init__(
self,
in_features: int,
out_features: int,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, lora_dropout=lora_dropout)

self.fan_in_fan_out = fan_in_fan_out
self.weight.requires_grad = False
self.reset_parameters()

if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)

def reset_parameters(self):
nn.Linear.reset_parameters(self)

def forward(self, x: torch.Tensor, AB: dict = None):

def T(w):
return w.transpose(1, 2) if self.fan_in_fan_out else w

result = F.linear(x, T(self.weight), bias=self.bias)

if AB is not None:
A = None
if isinstance(AB, dict):
B = AB['B']
A = AB.get('A')
else:
B = AB
if A is not None:
res = (B @ (A @ torch.permute(x, (1, 2, 0)).unsqueeze(1))).sum(1)
return result + torch.permute(res, (2, 0, 1))
res = (B @ torch.permute(x, (1, 2, 0)).unsqueeze(1)).sum(1)
return result + torch.permute(res, (2, 0, 1))

return result


class IncrementalClassifier(nn.Module):

def __init__(self, embed_dim: int, nb_classes: int):
"""
Incremental classifier for continual learning.
Args:
embed_dim: int, dimension of the input features.
nb_classes: int, number of classes to classify.
"""

super().__init__()

self.embed_dim = embed_dim

heads = [nn.Linear(embed_dim, nb_classes, bias=True)]

self.heads = nn.ModuleList(heads)
self.old_state_dict = None

for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def update(self, nb_classes: int, freeze_old=True):
"""
Add a new head to the classifier.
Args:
nb_classes, number of classes to add.
freeze_old: bool, whether to freeze the old heads.
"""

_fc = nn.Linear(self.embed_dim, nb_classes, bias=True).to(self.get_device())

nn.init.trunc_normal_(_fc.weight, std=.02)
nn.init.constant_(_fc.bias, 0)

if freeze_old:
for param in self.heads.parameters():
param.requires_grad = False

self.heads.append(_fc)

def forward(self, x: torch.Tensor):
"""
Forward pass.
Compute the logits for each head and concatenate them.
Args:
x: torch.Tensor, input features.
"""
return torch.cat([h(x) for h in self.heads], dim=1)
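
LoRALinear keeps its own weight frozen and, when an AB argument is supplied, adds a low-rank update built from externally provided matrices. A minimal sketch under assumed shapes; the batch-first token layout and the rank-4 A/B tensors below are illustrative assumptions, not something this file prescribes:

    import torch
    from backbone.utils.layers import LoRALinear

    layer = LoRALinear(768, 768)                 # frozen nn.Linear weight
    x = torch.randn(2, 197, 768)                 # (batch, tokens, features), e.g. ViT tokens

    rank = 4
    AB = {
        'A': torch.randn(2, 1, rank, 768),       # assumed low-rank down-projection
        'B': torch.randn(2, 1, 768, rank),       # assumed low-rank up-projection
    }
    out = layer(x, AB=AB)                        # F.linear(x, W) plus the injected low-rank term
    print(out.shape)                             # torch.Size([2, 197, 768])
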
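IncrementalClassifier starts with a single linear head, grows one head per task via update() (freezing earlier heads by default), and concatenates the logits of all heads in forward(). A usage sketch, assuming the get_device() helper referenced by update() (not shown in this hunk) resolves to the module's device:

    import torch
    from backbone.utils.layers import IncrementalClassifier

    clf = IncrementalClassifier(embed_dim=768, nb_classes=10)   # task 1: 10 classes
    feats = torch.randn(4, 768)
    print(clf(feats).shape)          # torch.Size([4, 10])

    clf.update(nb_classes=10)        # task 2: appends a new head, freezes the old ones
    print(clf(feats).shape)          # torch.Size([4, 20])
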