
Commit

reduce complexity by grouping each augment llm hparams, at the cost of an extra dataclass researchers will need to use when initializing
lucidrains committed Jan 19, 2024
1 parent e483ecb commit 6090fdc
Showing 4 changed files with 83 additions and 80 deletions.
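
The change in a nutshell: the per-augment keyword arguments on `CALM.__init__` (connections, input shape, block extraction fn, mask kwarg, connection frequency) are folded into an `AugmentParams` dataclass, one instance per augment LLM. A condensed before/after sketch of initialization, assembled from the README diff below; `anchor_llm` and `augment_llm` stand for the x-transformers models constructed earlier in that README:

```python
# before this commit: per-augment hparams passed as flat keyword arguments
from CALM_pytorch import CALM

calm = CALM(
    anchor_llm,
    augment_llm,
    augment_every_num_layers = 4
)

# after this commit: the same hparams grouped into one AugmentParams per augment llm
from CALM_pytorch import CALM, AugmentParams

calm = CALM(
    anchor_llm,
    augment_llms = AugmentParams(
        model = augment_llm,
        connect_every_num_layers = 4
    )
)
```
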
113 changes: 53 additions & 60 deletions CALM_pytorch/CALM.py
@@ -3,6 +3,8 @@
from functools import partial
from contextlib import nullcontext

from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch.nn import Module, ModuleList
@@ -196,51 +198,46 @@ def forward(self, _, inp, out):

# main class

@dataclass
class AugmentParams:
model: Module
hidden_position: Union[Literal['input'], Literal['output']] = 'output'
transformer_blocks: Optional[List[Module]] = None
extract_blocks_fn: Optional[Callable[[Module], List[Module]]] = None
input_shape: Optional[Tuple[int, ...]] = None
connections: Optional[Tuple[Tuple[int, int], ...]] = None
connect_every_num_layers: int = 4 # in the paper, they do 4
mask_kwarg: Optional[str] = None

class CALM(Module):
@beartype
def __init__(
self,
anchor_llm: Module,
augment_llm: Union[Module, SequenceOf(Module)],
augment_llms: SingularOrMany(AugmentParams),
*,
attn_kwargs: dict = dict(
linear_project_context = True,
pre_rmsnorm = True,
flash = True
),
connections: Optional[SingularOrMany(Tuple[Tuple[int, int], ...],)] = None,
input_shape: SingularOrMany(Optional[Tuple]) = None,
forward_mask_to_augment_llm_key: SingularOrMany(Optional[str]) = None, # if set, will forward the prompt_mask to the augment LLM (in case it is an encoder) with this key
augment_every_num_layers: int = 4, # in the paper, they do 4
augment_extract_layers_fn: SingularOrMany(Optional[Callable[[Module], List[Module]]]) = None,
augment_llm_mask_kwarg: SingularOrMany(Optional[str]) = None,
anchor_extract_layers_fn: Callable[[Module], List[Module]] = None,
augment_transformer_blocks: Optional[Union[List[List[Module]], List[Module]]] = None,
anchor_extract_blocks_fn: Callable[[Module], List[Module]] = None,
anchor_transformer_blocks: Optional[List[Module]] = None,
anchor_get_hidden_position: Union[Literal['input'], Literal['output']] = 'output',
augment_get_hidden_positions: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output',
pad_id: int = -1
):
super().__init__()

# account for single augmentation llm (which is what the paper did)
# in this repo, generalizing it to multiple augmentation llms

augment_llms = cast_tuple(augment_llm)
if isinstance(augment_llms, AugmentParams):
augment_llms = [augment_llms]

if exists(augment_transformer_blocks):
if is_bearable(augment_transformer_blocks, List[Module]):
augment_transformer_blocks = [augment_transformer_blocks]

if exists(connections):
if is_bearable(connections, Tuple[Tuple[int, int], ...]):
connections = [connections]
augment_llms_params = augment_llms

# main contribution of paper
# is showing that both anchor and augment can be frozen, and that cross attention from anchor -> augment every few layers outperforms lora

self.anchor_llm = anchor_llm
self.augment_llms = nn.ModuleList(augment_llms)
self.augment_llms = nn.ModuleList([params.model for params in augment_llms_params])

freeze_all_layers_(self.anchor_llm)
freeze_all_layers_(self.augment_llms)
@@ -256,18 +253,16 @@ def __init__(
# derive the blocks from the model and extraction function, if not

if not exists(anchor_transformer_blocks):
get_anchor_blocks_fn = x_transformer_blocks if isinstance(anchor_llm, TransformerWrapper) else get_anchor_transformer_blocks_fn
get_anchor_blocks_fn = x_transformer_blocks if isinstance(anchor_llm, TransformerWrapper) else anchor_extract_blocks_fn
anchor_transformer_blocks = get_anchor_blocks_fn(self.anchor_llm)

if not exists(augment_transformer_blocks):
augment_extract_layers_fn = cast_tuple(augment_extract_layers_fn, num_augment_llms)

augment_transformer_blocks = []
for params in augment_llms_params:
if exists(params.transformer_blocks):
continue

for augment_llm, extract in zip(self.augment_llms, augment_extract_layers_fn):
extract = default(extract, x_transformer_blocks if isinstance(augment_llm, TransformerWrapper) else None)
assert exists(extract)
augment_transformer_blocks.append(extract(augment_llm))
extract = default(params.extract_blocks_fn, x_transformer_blocks if isinstance(params.model, TransformerWrapper) else None)
assert exists(extract)
params.transformer_blocks = extract(params.model)

# extract all forward outputs from all transformer blocks
# for sanitizing the input (making sure transformer blocks are ordered by execution)
@@ -277,18 +272,18 @@ def __init__(

recording_inputs = [default_transformer_input]

input_shapes = cast_tuple(input_shape, num_augment_llms)
for params in augment_llms_params:
maybe_input_shape = params.input_shape

for maybe_one_input_shape in input_shapes:
if exists(maybe_one_input_shape):
inp = torch.randn((1, *maybe_one_input_shape))
if exists(maybe_input_shape):
inp = torch.randn((1, *maybe_input_shape))
else:
inp = default_transformer_input

recording_inputs.append(inp)

all_blocks = [anchor_transformer_blocks, *augment_transformer_blocks]
all_models = [anchor_llm, *augment_llms]
all_blocks = [anchor_transformer_blocks, *[params.transformer_blocks for params in augment_llms_params]]
all_models = [anchor_llm, *self.augment_llms]

all_outputs = [extract_forward_inputs(model, recording_input, blocks) for model, recording_input, blocks in zip(all_models, recording_inputs, all_blocks)]

@@ -301,31 +296,31 @@ def __init__(
# calculation for determining every Nth layer of augmentation layer hiddens is attended to
# in paper, they did every 4th layer of 1 augmentation llm

if not exists(connections):
connections = []
for params, augment_outputs in zip(augment_llms_params, augments_outputs):
if exists(params.connections):
continue

for augment_outputs in augments_outputs:

one_num_augment_blocks = len(augment_outputs)
one_num_augment_blocks = len(augment_outputs)

num_attended_augment_hiddens = ceil(one_num_augment_blocks / augment_every_num_layers)
num_cross_attending_anchor_blocks = min(num_attended_augment_hiddens, num_anchor_blocks)
anchor_every_num_layers = num_anchor_blocks // num_cross_attending_anchor_blocks
num_attended_augment_hiddens = ceil(one_num_augment_blocks / params.connect_every_num_layers)
num_cross_attending_anchor_blocks = min(num_attended_augment_hiddens, num_anchor_blocks)
anchor_every_num_layers = num_anchor_blocks // num_cross_attending_anchor_blocks

anchor_layer_indices = [*range(0, len(anchor_outputs), anchor_every_num_layers)]
augment_layer_indices = [*range(0, len(one_augment_transformer_blocks), augment_every_num_layers)]
# using 1 indexed, to save on confusion when manually defining connection layer
# (some researchers will probably not understand 0th layer == 1)

connections.append(tuple(zip(anchor_layer_indices, augment_layer_indices)))
anchor_layer_indices = [*range(1, len(anchor_outputs) + 1, anchor_every_num_layers)]
augment_layer_indices = [*range(1, len(augment_outputs) + 1, params.connect_every_num_layers)]

assert len(connections) == num_augment_llms
params.connections = tuple(zip(anchor_layer_indices, augment_layer_indices))

self.connections = connections
self.connections = [params.connections for params in augment_llms_params]

# from connections, get all paired transformer blocks between anchor and augments

anchor_to_augment_outputs = []

for connection, augment_outputs in zip(connections, augments_outputs):
for connection, augment_outputs in zip(self.connections, augments_outputs):

one_num_augment_blocks = len(augment_outputs)

@@ -341,10 +336,6 @@ def __init__(
(one_anchor_outputs, one_augment_outputs)
)

# for deriving hidden dimensions magically

augment_get_hidden_positions = cast_tuple(augment_get_hidden_positions, num_augment_llms)

# function for getting output or input dimension
# depending on get_hidden_position

@@ -355,7 +346,9 @@ def get_hidden_dim(hook_output: Tuple[Module, Tensor, Tensor], position: Union[L

# instantiate cross attentions

for (one_anchor_outputs, one_augment_outputs), augment_llm, augment_position in zip(anchor_to_augment_outputs, self.augment_llms, augment_get_hidden_positions):
for (one_anchor_outputs, one_augment_outputs), params in zip(anchor_to_augment_outputs, augment_llms_params):

augment_llm, augment_position = params.model, params.hidden_position

# number of cross attention for one augmentation llm

@@ -392,7 +385,7 @@ def get_hidden_dim(hook_output: Tuple[Module, Tensor, Tensor], position: Union[L

# forwarding a mask to augment llm

self.forward_mask_to_augment_llm_key = forward_mask_to_augment_llm_key
self.augment_llms_params = augment_llms_params

def state_dict(self):
return self.cross_attns.state_dict()
@@ -449,11 +442,11 @@ def forward(

self.augment_llms.eval()

for augment_llm, prompt, prompt_mask in zip(self.augment_llms, prompts, prompt_masks):
for augment_llm, params, prompt, prompt_mask in zip(self.augment_llms, self.augment_llms_params, prompts, prompt_masks):
augment_llm_kwarg = dict()

if exists(self.forward_mask_to_augment_llm_key):
augment_llm_kwarg = {self.forward_mask_to_augment_llm_key: prompt_mask}
if exists(params.mask_kwarg):
augment_llm_kwarg = {params.mask_kwarg: prompt_mask}

augment_llm(prompt, **augment_llm_kwarg)

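When `AugmentParams.connections` is left unset, the `__init__` changes above derive 1-indexed (anchor layer, augment layer) pairs from `connect_every_num_layers`. A standalone sketch of that derivation, under the simplifying assumption that every hooked transformer block yields exactly one recorded hidden state:

```python
from math import ceil

def derive_connections(num_anchor_blocks: int, num_augment_blocks: int, connect_every_num_layers: int = 4):
    # how many augment hiddens get attended to, and how sparsely the anchor cross attends in return
    num_attended_augment_hiddens = ceil(num_augment_blocks / connect_every_num_layers)
    num_cross_attending_anchor_blocks = min(num_attended_augment_hiddens, num_anchor_blocks)
    anchor_every_num_layers = num_anchor_blocks // num_cross_attending_anchor_blocks

    # 1-indexed on both sides, matching the convention this commit adopts
    anchor_layer_indices = range(1, num_anchor_blocks + 1, anchor_every_num_layers)
    augment_layer_indices = range(1, num_augment_blocks + 1, connect_every_num_layers)
    return tuple(zip(anchor_layer_indices, augment_layer_indices))

# a 12 layer anchor with a 12 layer augment, connected every 4th layer
assert derive_connections(12, 12) == ((1, 1), (5, 5), (9, 9))
```
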
6 changes: 5 additions & 1 deletion CALM_pytorch/__init__.py
@@ -1 +1,5 @@
from CALM_pytorch.CALM import CALM, FineTuner
from CALM_pytorch.CALM import (
AugmentParams,
CALM,
FineTuner
)
42 changes: 24 additions & 18 deletions README.md
@@ -46,12 +46,14 @@ anchor_llm = TransformerWrapper(

# import CALM wrapper

from CALM_pytorch import CALM
from CALM_pytorch import CALM, AugmentParams

calm = CALM(
anchor_llm,
augment_llm,
augment_every_num_layers = 4
augment_llms = AugmentParams(
model = augment_llm,
connect_every_num_layers = 4
)
)

# mock input
@@ -106,21 +108,25 @@ Say you want to explore different types of connectivity between anchor and augme
```python
calm = CALM(
anchor_llm = anchor_llm,
augment_llm = [augment_llm1, augment_llm2],
connections = [
(
(12, 1), # 12th layer of anchor attends to 1st layer of augment llm1
(12, 2),
(12, 3),
(12, 4),
),
(
(1, 6), # 1st layer of anchor attends to 6th layer of augment llm2
(2, 6),
(12, 12),
augment_llms = (
AugmentParams(
model = augment_llm1,
connections = (
(12, 1), # 12th layer of anchor attends to 1st layer of augment llm1
(12, 2),
(12, 3),
(12, 4),
),
),
# ... and so on, add vision transformers, whatever
]
AugmentParams(
model = augment_llm2,
connections = (
(1, 6), # 1st layer of anchor attends to 6th layer of augment llm2
(2, 6),
(12, 12),
)
)
)
)
```

@@ -136,8 +142,8 @@ calm = CALM(
- [x] full connectivity customization
- [x] custom number of augmentation layers per augmentation llm
- [x] make simple vit work
- [x] refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}} - use dataclasses
- [ ] show example
- [ ] refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}} - use `TypedDict` + beartype for validation

- [ ] handle a wrapper or function that takes in the sequence and prompt length, and auto derives the inputs to CALM
- [ ] add an option for self attention path way with memory tokens attending to hidden states of all augmentation llms, akin to what was done with <a href="https://github.com/lucidrains/zorro-pytorch">Zorro</a>
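The README examples above exercise `connections` and `connect_every_num_layers`; the remaining `AugmentParams` fields cover the other per-augment hparams that used to be separate keyword arguments. A hedged sketch of how they might be used to mix a text augment with a vision transformer; the `vit` module, its `blocks` attribute, and the `(3, 256, 256)` input shape are hypothetical placeholders rather than anything defined in this repository:

```python
from CALM_pytorch import CALM, AugmentParams

vit_params = AugmentParams(
    model = vit,                                   # hypothetical vision transformer Module
    input_shape = (3, 256, 256),                   # dummy input shape used when tracing its blocks
    extract_blocks_fn = lambda m: list(m.blocks),  # hypothetical accessor for its transformer blocks
    hidden_position = 'output',                    # cross attend to each block's output hiddens
    connect_every_num_layers = 4
)

text_params = AugmentParams(
    model = augment_llm,      # an x-transformers TransformerWrapper, whose blocks are auto-detected
    mask_kwarg = 'mask'       # assumed kwarg name; forwards the prompt mask to this augment
)

calm = CALM(
    anchor_llm,
    augment_llms = (text_params, vit_params)
)
```
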
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'CALM-Pytorch',
packages = find_packages(exclude=[]),
version = '0.0.33',
version = '0.1.0',
license='MIT',
description = 'CALM - Pytorch',
author = 'Phil Wang',