New ViT backbone, validation, and code optimizations #40

Merged
11 commits merged on Jun 16, 2024
3 changes: 3 additions & 0 deletions NOTICE.md
@@ -0,0 +1,3 @@
This project is licensed under the MIT License. However, some files in this repository are under the Apache 2.0 License as noted below:

- backbone/vit.py (Apache 2.0 License)
5 changes: 3 additions & 2 deletions backbone/EfficientNet.py
@@ -155,7 +155,7 @@ def round_filters(filters, global_params):
multiplier = global_params.width_coefficient
if not multiplier:
return filters

divisor = global_params.depth_divisor
min_depth = global_params.min_depth
filters *= multiplier
@@ -244,6 +244,7 @@ def get_same_padding_conv2d(image_size=None):
else:
return partial(Conv2dStaticSamePadding, image_size=image_size)


GlobalParams = collections.namedtuple('GlobalParams', [
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format',
'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor',
@@ -848,7 +849,7 @@ def forward(self, inputs, returnt='out'):
x = self.classifier(x)
if returnt == 'out':
return x
-elif returnt == 'all':
+elif returnt == 'full':
return (x, feats)

raise NotImplementedError("Unknown return type")
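
This PR renames the combined return mode from 'all' to 'full' in this hunk and in the MNISTMLP and ResNet backbones below. As a purely illustrative sketch of the resulting forward(x, returnt=...) contract (ToyBackbone is not part of the diff; its dimensions are arbitrary):

import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Illustrative stand-in that mirrors the returnt convention of the real backbones."""

    def __init__(self, in_dim: int = 784, feat_dim: int = 100, n_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.ReLU())
        self.classifier = nn.Linear(feat_dim, n_classes)

    def forward(self, x: torch.Tensor, returnt: str = 'out'):
        feats = self.features(x)
        out = self.classifier(feats)
        if returnt == 'out':
            return out
        elif returnt == 'features':
            return feats
        elif returnt == 'full':  # renamed from 'all' in this PR
            return (out, feats)
        raise NotImplementedError("Unknown return type")


net = ToyBackbone()
x = torch.randn(8, 784)
logits, feats = net(x, returnt='full')  # logits: (8, 10), feats: (8, 100)
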
2 changes: 1 addition & 1 deletion backbone/MNISTMLP.py
@@ -68,7 +68,7 @@ def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:

if returnt == 'out':
return out
-elif returnt == 'all':
+elif returnt == 'full':
return (out, feats)

raise NotImplementedError("Unknown return type")
2 changes: 1 addition & 1 deletion backbone/ResNet18_PNN.py
@@ -10,7 +10,7 @@
import torch.nn.functional as F
from torch.nn.functional import avg_pool2d, relu

-from backbone.ResNet18 import BasicBlock, ResNet, conv3x3
+from backbone.ResNetBlock import BasicBlock, ResNet, conv3x3
from backbone.utils.modules import AlphaModule, ListModule


18 changes: 17 additions & 1 deletion backbone/ResNet18.py → backbone/ResNetBlock.py
@@ -110,6 +110,8 @@ def __init__(self, block: BasicBlock, num_blocks: List[int],
self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
self.classifier = nn.Linear(nf * 8 * block.expansion, num_classes)

self.feature_dim = nf * 8 * block.expansion

def to(self, device, **kwargs):
self.device = device
return super().to(device, **kwargs)
@@ -185,7 +187,7 @@ def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
out_4 if not self.return_prerelu else self.layer4[-1].prerelu
]

raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'all'] but got {}".format(returnt))
raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'full'] but got {}".format(returnt))


def resnet18(nclasses: int, nf: int = 64) -> ResNet:
@@ -200,3 +202,17 @@ def resnet18(nclasses: int, nf: int = 64) -> ResNet:
ResNet network
"""
return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf)


def resnet34(nclasses: int, nf: int = 64) -> ResNet:
"""
Instantiates a ResNet34 network.

Args:
nclasses: number of output classes
nf: number of filters

Returns:
ResNet network
"""
return ResNet(BasicBlock, [3, 4, 6, 3], nclasses, nf)
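
A quick sketch of the new factory together with the feature_dim attribute added earlier in this file (the class count below is arbitrary, not taken from the PR):

from backbone.ResNetBlock import resnet34

net = resnet34(nclasses=100)   # BasicBlock layout [3, 4, 6, 3]
print(net.feature_dim)         # 512 with the default nf=64 (nf * 8 * block.expansion)
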
4 changes: 3 additions & 1 deletion backbone/ResNet50.py → backbone/ResNetBottleneck.py
@@ -146,6 +146,8 @@ def __init__(
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(512 * block.expansion, num_classes)

self.feature_dim = 512 * block.expansion

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
@@ -242,7 +244,7 @@ def forward(self, x: Tensor, returnt="out") -> Tensor:
elif returnt == 'both':
return (out, feature)

-raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'all'] but got {}".format(returnt))
+raise NotImplementedError("Unknown return type. Must be in ['out', 'features', 'both', 'full'] but got {}".format(returnt))

def set_grad_filter(self, filter_s: str, enable: bool) -> None:
negative_mode = filter_s[0] == '~'
167 changes: 167 additions & 0 deletions backbone/utils/layers.py
@@ -0,0 +1,167 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer():
def __init__(
self,
lora_dropout: float,
):
# Optional dropout
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x


class LoRALinear(nn.Linear, LoRALayer):

def __init__(
self,
in_features: int,
out_features: int,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, lora_dropout=lora_dropout)

self.fan_in_fan_out = fan_in_fan_out
self.weight.requires_grad = False
self.reset_parameters()

if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)

def reset_parameters(self):
nn.Linear.reset_parameters(self)

def forward(self, x: torch.Tensor, AB: dict = None):

def T(w):
    # the weight is 2D; undo the (in_features, out_features) storage layout used when fan_in_fan_out is True
    return w.transpose(0, 1) if self.fan_in_fan_out else w

result = F.linear(x, T(self.weight), bias=self.bias)

if AB is not None:
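# Low-rank (LoRA-style) residual update applied on top of the frozen linear projection.
# AB is either a dict holding a down-projection 'A' and an up-projection 'B', or the
# 'B' matrix alone; x is expected as (batch, tokens, channels).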
A = None
if isinstance(AB, dict):
B = AB['B']
A = AB.get('A')
else:
B = AB
if A is not None:
return result + (B @ (A @ x.transpose(1, 2).unsqueeze(1))).sum(1).transpose(1, 2)
return result + (B @ x.transpose(1, 2).unsqueeze(1)).sum(1).transpose(1, 2)

return result


class ClipLinear(nn.Linear, LoRALayer):

def __init__(
self,
in_features: int,
out_features: int,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, lora_dropout=lora_dropout)

self.fan_in_fan_out = fan_in_fan_out
self.weight.requires_grad = False
self.reset_parameters()

if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)

def reset_parameters(self):
nn.Linear.reset_parameters(self)

def forward(self, x: torch.Tensor, AB: dict = None):

def T(w):
    # the weight is 2D; undo the (in_features, out_features) storage layout used when fan_in_fan_out is True
    return w.transpose(0, 1) if self.fan_in_fan_out else w

result = F.linear(x, T(self.weight), bias=self.bias)

if AB is not None:
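# Same low-rank update as in LoRALinear, but here x is laid out sequence-first
# (tokens, batch, channels), hence the permutes instead of transposes.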
A = None
if isinstance(AB, dict):
B = AB['B']
A = AB.get('A')
else:
B = AB
if A is not None:
res = (B @ (A @ torch.permute(x, (1, 2, 0)).unsqueeze(1))).sum(1)
return result + torch.permute(res, (2, 0, 1))
res = (B @ torch.permute(x, (1, 2, 0)).unsqueeze(1)).sum(1)
return result + torch.permute(res, (2, 0, 1))

return result


class IncrementalClassifier(nn.Module):

def __init__(self, embed_dim: int, nb_classes: int):
"""
Incremental classifier for continual learning.

Args:
embed_dim: int, dimension of the input features.
nb_classes: int, number of classes to classify.
"""

super().__init__()

self.embed_dim = embed_dim

heads = [nn.Linear(embed_dim, nb_classes, bias=True)]

self.heads = nn.ModuleList(heads)
self.old_state_dict = None

for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def update(self, nb_classes: int, freeze_old=True):
"""
Add a new head to the classifier.

Args:
nb_classes: int, number of classes to add.
freeze_old: bool, whether to freeze the old heads.
"""

_fc = nn.Linear(self.embed_dim, nb_classes, bias=True).to(self.heads[-1].weight.device)  # place the new head on the same device as the existing ones

nn.init.trunc_normal_(_fc.weight, std=.02)
nn.init.constant_(_fc.bias, 0)

if freeze_old:
for param in self.heads.parameters():
param.requires_grad = False

self.heads.append(_fc)

def forward(self, x: torch.Tensor):
"""
Forward pass.

Compute the logits for each head and concatenate them.

Args:
x: torch.Tensor, input features.
"""
return torch.cat([h(x) for h in self.heads], dim=1)
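
A minimal usage sketch for the new layers follows. It is not part of the diff, and the dimensions, rank, and AB shapes are assumptions inferred from the broadcasting in LoRALinear.forward rather than anything documented by the PR.

import torch
from backbone.utils.layers import LoRALinear, IncrementalClassifier

# Frozen 16 -> 32 projection with an optional rank-4 low-rank update.
layer = LoRALinear(16, 32, bias=True)
x = torch.randn(2, 5, 16)                  # (batch, tokens, channels)
AB = {'A': torch.randn(1, 4, 16),          # down-projection: (blocks, rank, in_features)
      'B': torch.randn(1, 32, 4)}          # up-projection:   (blocks, out_features, rank)
out_plain = layer(x)                       # frozen linear only
out_lora = layer(x, AB=AB)                 # frozen linear + low-rank residual
print(out_lora.shape)                      # torch.Size([2, 5, 32])

# Incremental head: start with 10 classes, then add 10 more for a new task.
clf = IncrementalClassifier(embed_dim=32, nb_classes=10)
clf.update(10)                             # previous heads are frozen by default
logits = clf(out_lora.mean(dim=1))         # concatenated head outputs: (2, 20)
print(logits.shape)
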