Skip to content

Commit

Permalink
fix bug and show end to end example with anchor llm exposed to 2 spec…
Browse files Browse the repository at this point in the history
…ialized llms + eyes
  • Loading branch information
lucidrains committed Jan 19, 2024
1 parent 80badc2 commit fd7836a
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 5 deletions.
6 changes: 4 additions & 2 deletions CALM_pytorch/CALM.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def __init__(
self.recorder = recorder
self.context_proj = None

self.dim = dim
self.dim_context = dim_context

if linear_project_context:
self.context_proj = nn.Linear(dim_context, dim)
dim_context = dim
Expand Down Expand Up @@ -328,7 +331,6 @@ def __init__(
anchor_to_augment_outputs = []

for connection, params, augment_outputs in zip(self.connections, augment_llms_params, augments_outputs):

one_num_augment_blocks = len(augment_outputs)

anchor_layer_indices, augment_layer_indices = tuple(zip(*connection))
Expand Down Expand Up @@ -378,7 +380,7 @@ def get_hidden_dim(hook_output: Tuple[Module, Tensor, Tensor], position: Union[L

# connect the two models

for (anchor_block, *_), recorder, cross_attn, (augment_block, *_) in zip(anchor_outputs, recorders, one_augment_llm_cross_attns, augment_outputs):
for ((anchor_block, *_), *_), recorder, cross_attn, ((augment_block, *_), *_) in zip(one_anchor_outputs_and_positions, recorders, one_augment_llm_cross_attns, one_augment_outputs_and_positions):
augment_block.register_forward_hook(recorder)
anchor_block.register_forward_hook(cross_attn)

Expand Down
102 changes: 100 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,108 @@ calm = CALM(
(2, 6),
(12, 12),
)
)
)
)
```

CALM setup with 2 specialized augmentation LLMs + a vision transformer

```python
# End-to-end example: one anchor decoder is composed with three augmentation
# models (two token encoders and a vision transformer), a loss is computed on
# random data, and gradients are backpropagated.

import torch

# pip install vit-pytorch x-transformers

from vit_pytorch.vit import ViT, Attention
from x_transformers import TransformerWrapper, Encoder, Decoder

# anchor model: 12-layer decoder over a 20000-token vocabulary, width 16

anchor_llm = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)

# first augmentation model: encoder with the same vocabulary size, width and depth as the anchor

augment_llm1 = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Encoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)

# second augmentation model: configured identically to the first, but a separate instance

augment_llm2 = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Encoder(
dim = 16,
dim_head = 2,
depth = 12,
heads = 8
)
)

# vision transformer: 256x256 images in 32x32 patches
# note its width (256) and depth (6) differ from the text models above (16 and 12)

vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 256,
depth = 6,
heads = 16,
mlp_dim = 2048
)

# calm

# NOTE(review): FineTuner is imported but not used anywhere in this example
from CALM_pytorch import CALM, AugmentParams, FineTuner

calm = CALM(
anchor_llm = anchor_llm,
augment_llms = (
# mask_kwarg is presumably the name of the keyword under which the augment model accepts its mask - confirm against AugmentParams
AugmentParams(
model = augment_llm1,
mask_kwarg = 'mask'
),
AugmentParams(
model = augment_llm2,
mask_kwarg = 'mask'
),
# ... perhaps other modalities, vision / audio transformer etc
AugmentParams(
model = vit,
# matches the image prompt below without its leading batch dimension
input_shape = (3, 256, 256),
# collects every vit_pytorch Attention module found inside the ViT
# presumably these are the blocks whose outputs CALM reads - confirm against CALM
extract_blocks_fn = lambda vit: [m for m in vit.modules() if isinstance(m, Attention)]
)
),
# presumably forwarded to the cross attention modules CALM creates - confirm against CALM
# linear_project_context is likely what lets the width-256 ViT feed the width-16 anchor - TODO confirm
attn_kwargs = dict(
linear_project_context = True,
pre_rmsnorm = True,
flash = True
)
)

# random token ids for the anchor (batch of 1, length 1024) and an all-True mask of the same shape

seq = torch.randint(0, 20000, (1, 1024))
mask = torch.ones((1, 1024)).bool()

# three prompts, apparently one per augmentation model in the order given to `augment_llms`:
# two token sequences of length 256, then one image batch

prompt = (
torch.randint(0, 20000, (1, 256)),
torch.randint(0, 20000, (1, 256)),
torch.randn(1, 3, 256, 256)
)

# calling the CALM instance returns the loss

loss = calm(
seq,
mask = mask,
prompt = prompt
)

loss.backward()
```

## Todo
Expand All @@ -144,7 +242,7 @@ calm = CALM(
- [x] custom number of augmentation layers per augmentation llm
- [x] make simple vit work
- [x] refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}} - use dataclasses
- [ ] show example
- [x] show example

- [ ] when finely specifying hidden positions, make sure to reorder it if the transformer blocks themselves were passed in and not ordered to begin with
- [ ] take care of caching the augment hiddens when sampling. forget about anchor kv cache for now
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'CALM-Pytorch',
packages = find_packages(exclude=[]),
version = '0.1.1',
version = '0.1.2',
license='MIT',
description = 'CALM - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit fd7836a

Please sign in to comment.