Add new datasets and models #43

Merged 76 commits on Jul 25, 2024
Commits
e8a3ad6
Merge pull request #34 from grupposimo/dev
loribonna Jun 15, 2024
e2e2d36
Merge pull request #35 from grupposimo/dev
loribonna Jun 16, 2024
ead5a55
Add savecheck different modes
loribonna Jun 19, 2024
c1f3e94
Add test environment initialization and debug environment. Minor fixe…
loribonna Jun 28, 2024
2628c43
minor updates
loribonna Jun 28, 2024
0529fb3
Add initial stuff for starprompt
loribonna Jul 2, 2024
5b49d0a
minor fix load keys. update device set for backbone
loribonna Jul 3, 2024
8e12101
updated requirements with new pytorch version for scaled dot prod
loribonna Jul 3, 2024
0ece353
Minor update
loribonna Jul 3, 2024
ef1747d
More minor updates
loribonna Jul 3, 2024
6808c80
Fix autopep error
loribonna Jul 3, 2024
9c03dcd
add cars196
loribonna Jul 3, 2024
4b4690d
add chestx
loribonna Jul 3, 2024
927264e
add cropdisease and eurosat. add templates to clip. update vit
loribonna Jul 3, 2024
af23d29
add isic, mit67, and resisc45
loribonna Jul 3, 2024
03d7eb9
updated savecheck jobid. add templates to c10/100-224. Minor changes …
loribonna Jul 4, 2024
f349e89
Merge remote-tracking branch 'grupposimo/dev' into starprompt
loribonna Jul 4, 2024
53c715a
Update imgr,cub200
loribonna Jul 4, 2024
4ceb7a0
updated defaults, minor fixes
loribonna Jul 4, 2024
6621709
minor update stats if no gpu
loribonna Jul 4, 2024
4ce4f9c
cont fix stats with no gpu
loribonna Jul 4, 2024
db0d565
re: fix stats with no gpu
loribonna Jul 4, 2024
9a4fda6
Add prefix tuning option
loribonna Jul 4, 2024
af42363
upd cars
loribonna Jul 5, 2024
d12742f
fix cars transform. update args docs
loribonna Jul 5, 2024
a6be09b
add resize to kornia transform. minor updates
loribonna Jul 5, 2024
2960cce
fix gpu mem measure
loribonna Jul 5, 2024
8fbf5ad
add use_data_aug icoop
loribonna Jul 5, 2024
3c0506d
add gaussian mode for second stage starprompt
loribonna Jul 5, 2024
6ae20b2
maybe fix some weird bug with dataloader workers
loribonna Jul 5, 2024
9f13a32
add seeded dataloader to models
loribonna Jul 5, 2024
c895f2d
updated worker init fn
loribonna Jul 5, 2024
58cd143
fix mog generation with nans
loribonna Jul 5, 2024
1db0122
fix second stage prompt without keys
loribonna Jul 7, 2024
1216bfa
Minor upd
loribonna Jul 8, 2024
28dcf64
fix cars split
loribonna Jul 8, 2024
8ea33ac
debugging cars
loribonna Jul 8, 2024
4cfbf6d
minor change
loribonna Jul 8, 2024
0ebf081
removed unnecessary default
loribonna Jul 9, 2024
83d1353
Added AttriCLIP
MartinMenabue Jul 9, 2024
b6cf71c
updated starprompt keys load
loribonna Jul 9, 2024
8e2513c
fix denorm for tensors
loribonna Jul 9, 2024
826f9a5
updated default for permute_classes
loribonna Jul 9, 2024
7d3cded
minor fix denorm
loribonna Jul 9, 2024
2b16d07
fix cropdisease epochs
loribonna Jul 9, 2024
107d1f5
fix clip performance when loading second stage
loribonna Jul 9, 2024
0e35358
add data aug query in train
loribonna Jul 10, 2024
a8161bb
add enable_data_aug_query. fix default
loribonna Jul 10, 2024
bd2b924
minor fix get device
loribonna Jul 10, 2024
52d8d6c
fix `compute_offsets`
loribonna Jul 10, 2024
22b72eb
didn't i just fix it?
loribonna Jul 10, 2024
c5331d2
upd c100 224
loribonna Jul 10, 2024
a0f02d4
Fixed keys loading on joint
MartinMenabue Jul 11, 2024
f131994
using test aug in eval
loribonna Jul 11, 2024
3a100ef
Merge branch 'master' of https://github.com/loribonna/starprompt
loribonna Jul 11, 2024
f0c7033
add check for scaled dot prod
loribonna Jul 12, 2024
87246d3
Undo fix keys loading on joint
MartinMenabue Jul 12, 2024
3311f98
Fix second-stage orthogonalization init. Code reproduces. Add minor c…
loribonna Jul 19, 2024
dc04407
Autopep
loribonna Jul 19, 2024
b211162
Minor update docs, add tests for starprompt
loribonna Jul 19, 2024
82c9db0
Add end-to-end starprompt (beta)
loribonna Jul 20, 2024
0118411
Merge branch 'master' of https://github.com/loribonna/starprompt
loribonna Jul 20, 2024
5156f0c
minor fix starprompt
loribonna Jul 22, 2024
716a5bb
Minor changes
loribonna Jul 23, 2024
f0d7da6
add cgil, add tqdm instead of progbar and logging.info
apanariello4 Jul 24, 2024
324bc26
Merge branch 'master' of https://github.com/grupposimo/mammoth
loribonna Jul 24, 2024
6f48e15
Merge remote-tracking branch 'upstream/dev'
loribonna Jul 24, 2024
5569c91
uniformed logging. Update e2e, first, and second stage starprompt. up…
loribonna Jul 24, 2024
c4a4171
Add args to models in docs. Minor changes
loribonna Jul 24, 2024
3ee1940
fix epochs first stage parsing
loribonna Jul 24, 2024
59a4638
Updated docs. Add starprompt tests.
loribonna Jul 25, 2024
0a5e53f
Add cub200 with resnet50. Updated tests and docs. Minor fixes. Tests …
loribonna Jul 25, 2024
14a2e51
Merge pull request #36 from grupposimo/dev
loribonna Jul 25, 2024
ff0a335
Updated requirements
loribonna Jul 25, 2024
c92c4f9
Merge branch 'master' of https://github.com/grupposimo/mammoth
loribonna Jul 25, 2024
6349dd0
Merge branch 'aimagelab:master' into master
loribonna Jul 25, 2024
Files changed
2 changes: 1 addition & 1 deletion .github/workflows/deploy_pages.yml
@@ -39,7 +39,7 @@ jobs:
python-version: "3.10"
- name: Install dependencies
run: |
pip install -r docs/requirements.txt -r requirements.txt
pip install -r docs/requirements.txt -r requirements.txt -r requirements-optional.txt
pip install quadprog==0.1.11
- name: Sphinx build
run: |
7 changes: 6 additions & 1 deletion .gitignore
@@ -95,4 +95,9 @@ logs
**/_build/
_autosummary
generated
val_permutations
val_permutations

# Other prepare grid scripts except the example one
scripts/prepare_grid*

docs/models/*_args.rst
221 changes: 178 additions & 43 deletions README.md

Large diffs are not rendered by default.

5 changes: 0 additions & 5 deletions backbone/MNISTMLP.py
@@ -72,8 +72,3 @@ def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
return (out, feats)

raise NotImplementedError("Unknown return type")

def to(self, device):
super().to(device)
self.device = device
return self
4 changes: 0 additions & 4 deletions backbone/ResNetBlock.py
@@ -112,10 +112,6 @@ def __init__(self, block: BasicBlock, num_blocks: List[int],

self.feature_dim = nf * 8 * block.expansion

def to(self, device, **kwargs):
self.device = device
return super().to(device, **kwargs)

def set_return_prerelu(self, enable=True):
self.return_prerelu = enable
for c in self.modules():
6 changes: 6 additions & 0 deletions backbone/__init__.py
@@ -65,6 +65,12 @@ class MammothBackbone(nn.Module):

def __init__(self, **kwargs) -> None:
super(MammothBackbone, self).__init__()
self.device = torch.device('cpu') if 'device' not in kwargs else kwargs['device']

def to(self, device, *args, **kwargs):
super(MammothBackbone, self).to(device, *args, **kwargs)
self.device = device
return self

def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
"""
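
The diff above moves device tracking into the MammothBackbone base class: the per-backbone `to()` overrides removed from MNISTMLP.py and ResNetBlock.py are replaced by a single override that keeps `self.device` in sync whenever the module is moved. A minimal sketch of the resulting behavior, using a hypothetical toy backbone (not part of the PR) and assuming the package layout shown above:

import torch
import torch.nn as nn
from backbone import MammothBackbone  # defined in backbone/__init__.py above

class ToyBackbone(MammothBackbone):
    # Hypothetical subclass, only for illustration.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.fc = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor, returnt='out') -> torch.Tensor:
        return self.fc(x)

net = ToyBackbone()
print(net.device)          # defaults to cpu (set in MammothBackbone.__init__)
if torch.cuda.is_available():
    net = net.to('cuda:0')
    print(net.device)      # 'cuda:0': kept in sync by the base-class `to` override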
56 changes: 1 addition & 55 deletions backbone/utils/layers.py
@@ -6,61 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer():
def __init__(
self,
lora_dropout: float,
):
# Optional dropout
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x


class LoRALinear(nn.Linear, LoRALayer):

def __init__(
self,
in_features: int,
out_features: int,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, lora_dropout=lora_dropout)

self.fan_in_fan_out = fan_in_fan_out
self.weight.requires_grad = False
self.reset_parameters()

if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)

def reset_parameters(self):
nn.Linear.reset_parameters(self)

def forward(self, x: torch.Tensor, AB: dict = None):

def T(w):
return w.transpose(1, 2) if self.fan_in_fan_out else w

result = F.linear(x, T(self.weight), bias=self.bias)

if AB is not None:
A = None
if isinstance(AB, dict):
B = AB['B']
A = AB.get('A')
else:
B = AB
if A is not None:
return result + (B @ (A @ x.transpose(1, 2).unsqueeze(1))).sum(1).transpose(1, 2)
return result + (B @ x.transpose(1, 2).unsqueeze(1)).sum(1).transpose(1, 2)

return result
from backbone.utils.lora_utils import LoRALayer


class ClipLinear(nn.Linear, LoRALayer):
208 changes: 208 additions & 0 deletions backbone/utils/lora_utils.py
@@ -0,0 +1,208 @@
import collections.abc
from itertools import repeat
from torch import nn
import torch
import torch.nn.functional as F


def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse


to_2tuple = _ntuple(2)


class LoRALayer():
def __init__(
self,
lora_dropout: float,
):
# Optional dropout
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x


class LoRALinear(nn.Linear, LoRALayer):

def __init__(
self,
in_features: int,
out_features: int,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, lora_dropout=lora_dropout)

self.fan_in_fan_out = fan_in_fan_out
self.weight.requires_grad = False
self.reset_parameters()

if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)

def reset_parameters(self):
nn.Linear.reset_parameters(self)

def forward(self, x: torch.Tensor, AB: dict = None):

def T(w):
return w.transpose(1, 2) if self.fan_in_fan_out else w

result = F.linear(x, T(self.weight), bias=self.bias)

if AB is not None:
A = None
if isinstance(AB, dict):
B = AB['B']
A = AB.get('A')
else:
B = AB
if A is not None:
return result + (B @ (A @ x.transpose(1, 2).unsqueeze(1))).sum(1).transpose(1, 2)
return result + (B @ x.transpose(1, 2).unsqueeze(1)).sum(1).transpose(1, 2)

return result
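
The `AB` argument of `LoRALinear.forward` accepts either a bare `B` tensor or a dict with `B` and an optional `A`; the broadcasting above implies the factors carry a leading chunk dimension that is summed out. A minimal usage sketch under that assumption (the rank and shapes are illustrative, not taken from the PR):

import torch
from backbone.utils.lora_utils import LoRALinear  # module added by this PR

layer = LoRALinear(in_features=64, out_features=64, lora_dropout=0.)
x = torch.randn(2, 16, 64)              # (batch, tokens, features)

r = 8                                   # illustrative LoRA rank
AB = {
    'A': torch.randn(1, r, 64),         # (chunks, r, in_features)
    'B': torch.randn(1, 64, r),         # (chunks, out_features, r)
}
out = layer(x, AB)                      # frozen linear output + low-rank update
print(out.shape)                        # torch.Size([2, 16, 64])

Passing only a tensor (the `B`-only path) is also supported; in that case the tensor is applied directly to the input, so it needs the full (chunks, out_features, in_features) shape, e.g. `layer(x, torch.randn(1, 64, 64))`.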


class LoRAAttention(nn.Module):
"""
Attention layer as used in Vision Transformer.
Adapted to support LoRA-style parameters.

Args:
dim: Number of input channels
num_heads: Number of attention heads
qkv_bias: If True, add a learnable bias to q, k, v
attn_drop: Dropout rate for attention weights
proj_drop: Dropout rate after the final projection
"""

def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5

self.qkv = LoRALinear(dim, dim * 3, 0., bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = LoRALinear(dim, dim, 0.)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x, AB: dict = None, **kwargs):
"""
Forward pass of the attention layer.
Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`).

Args:
x: Input tensor
AB: Dictionary containing LoRA-style parameters for the layer
"""

B, N, C = x.shape

AB_qkv = None

if AB is not None:
AB_qkv = AB.get("qkv")

qkv = self.qkv(x, AB_qkv)
qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)

# NOTE: flash attention is less debuggable than the original. Use the commented code below if in trouble.
if torch.__version__ >= '2.1.0':
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.attn_drop.p)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v)

x = x.transpose(1, 2).reshape(B, N, C)

AB_proj = None

if AB is not None:
AB_proj = AB.get("proj")

x = self.proj(x, AB_proj)
x = self.proj_drop(x)

return x
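
`LoRAAttention` looks up per-projection factors by key, so the `AB` it expects is a nested dict with optional `qkv` and `proj` entries; missing keys simply skip the corresponding update. A minimal sketch with illustrative shapes, following the same convention as the `LoRALinear` example above:

import torch
from backbone.utils.lora_utils import LoRAAttention  # module added by this PR

dim, r = 64, 4                          # illustrative embedding size and rank
attn = LoRAAttention(dim, num_heads=8)
x = torch.randn(2, 16, dim)             # (batch, tokens, dim)

AB = {
    'qkv':  {'B': torch.randn(1, 3 * dim, r), 'A': torch.randn(1, r, dim)},
    'proj': {'B': torch.randn(1, dim, r),     'A': torch.randn(1, r, dim)},
}
out = attn(x, AB)
print(out.shape)                        # torch.Size([2, 16, 64])

Note that the fast path through `F.scaled_dot_product_attention` passes `scale=` and therefore needs PyTorch >= 2.1, which is why the forward guards it with a version check (and why the requirements are updated elsewhere in this PR).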


class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))

def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma


class LoRAMlp(nn.Module):
"""
MLP as used in Vision Transformer, MLP-Mixer and related networks.
Adapted to support LoRA-style parameters.
"""

def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)

assert use_conv is False

self.fc1 = LoRALinear(in_features, hidden_features, bias=bias[0], lora_dropout=0.)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = LoRALinear(hidden_features, out_features, bias=bias[1], lora_dropout=0.)
self.drop2 = nn.Dropout(drop_probs[1])

def forward(self, x: torch.Tensor, AB: dict = None, **kwargs):
"""
Forward pass of the MLP layer.
Supports `AB` for LoRA-style parameters (checkout docs for `VisionTransformer.forward`).

Args:
x: Input tensor
AB: Dictionary containing LoRA-style parameters for the layer
"""
AB_fc1 = None
AB_fc2 = None

if AB is not None:
AB_fc1 = AB.get("fc1")
AB_fc2 = AB.get("fc2")

x = self.fc1(x, AB_fc1)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x, AB_fc2)
x = self.drop2(x)

return x
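
`LoRAMlp` follows the same pattern with factors keyed by `fc1` and `fc2`. A compact sketch under the same illustrative assumptions:

import torch
from backbone.utils.lora_utils import LoRAMlp  # module added by this PR

mlp = LoRAMlp(in_features=64, hidden_features=256)
x = torch.randn(2, 16, 64)
r = 4                                   # illustrative rank
AB = {
    'fc1': {'B': torch.randn(1, 256, r), 'A': torch.randn(1, r, 64)},
    'fc2': {'B': torch.randn(1, 64, r),  'A': torch.randn(1, r, 256)},
}
print(mlp(x, AB).shape)                 # torch.Size([2, 16, 64])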