
Commit

make it so one can do multiple augment llms including vision
lucidrains committed Jan 17, 2024
1 parent 3beef15 commit c9ed44c
Showing 3 changed files with 50 additions and 32 deletions.
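
For context, a minimal sketch of how the generalized constructor introduced by this commit might be used with two augment LLMs, one of them a vision transformer. The x-transformers models, the `from CALM_pytorch import CALM` import path, and the vision layer-extraction lambda are illustrative assumptions rather than code from this commit; the keyword names follow the constructor signature in the diff below.

```python
import torch
from x_transformers import TransformerWrapper, ViTransformerWrapper, Decoder, Encoder

from CALM_pytorch import CALM  # assumed import path

# anchor (decoder) LLM whose hidden states will cross attend to the augment LLMs

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
)

# one text augment LLM and one vision augment model

augment_llm_text = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

augment_llm_vision = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llm = [augment_llm_text, augment_llm_vision],
    augment_every_num_layers = 4,
    # per-augment dummy input shape used when deriving hidden dimensions:
    # None falls back to a (1, 1) long tensor, the vision model gets a random image
    input_shape = (None, (3, 256, 256)),
    # per-augment layer extraction; None falls back to the built-in x-transformers
    # handling, and the vision lambda is hypothetical - adjust it to your model
    augment_extract_layers_fn = (
        None,
        lambda vit: list(vit.attn_layers.layers)
    )
)
```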
76 changes: 46 additions & 30 deletions CALM_pytorch/CALM.py
@@ -41,14 +41,23 @@
def SequenceOf(t):
return Union[Tuple[t, ...], List[t]]

def SingularOrMany(t):
return Union[t, SequenceOf(t)]

# helpers

def exists(v):
return v is not None

def default(v, d):
return v if exists(v) else d

def xnor(x, y):
return not (x ^ y)

def cast_tuple(t, length = 1):
return t if is_bearable(t, Sequence) else ((t,) * length)

# freezing llms

@beartype
@@ -165,32 +174,32 @@ class CALM(Module):
def __init__(
self,
anchor_llm: Module,
augment_llm: Union[Module, List[Module]],
augment_llm: Union[Module, SequenceOf(Module)],
*,
attn_kwargs: dict = dict(
linear_project_context = True,
pre_rmsnorm = True,
flash = True
),
forward_mask_to_augment_llm_key: Optional[str] = None, # if set, will forward the prompt_mask to the augment LLM (in case it is an encoder) with this key
augment_every_num_layers = 4, # in the paper, they do 4
get_augment_transformer_blocks_fn: Callable[[Module], List[Module]] = lambda module: module.blocks,
get_anchor_transformer_blocks_fn: Callable[[Module], List[Module]] = lambda module: module.blocks,
input_shape: SingularOrMany(Optional[Tuple]) = None,
forward_mask_to_augment_llm_key: SingularOrMany(Optional[str]) = None, # if set, will forward the prompt_mask to the augment LLM (in case it is an encoder) with this key
augment_every_num_layers: int = 4, # in the paper, they do 4
augment_extract_layers_fn: SingularOrMany(Optional[Callable[[Module], List[Module]]]) = None,
augment_llm_mask_kwarg: SingularOrMany(Optional[str]) = None,
anchor_extract_layers_fn: Callable[[Module], List[Module]] = None,
augment_transformer_blocks: Optional[Union[List[List[Module]], List[Module]]] = None,
anchor_transformer_blocks: Optional[List[Module]] = None,
anchor_to_augment_blocks: Optional[List[Tuple[List[Module], List[Module]]]] = None,
forward_hook_get_hidden: Union[Literal['input'], Literal['output']] = 'output',
pad_id = -1
forward_hook_get_hidden: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output',
anchor_forward_hook_get_hidden: Union[Literal['input'], Literal['output']] = 'output',
pad_id: int = -1
):
super().__init__()

# account for single augmentation llm (which is what the paper did)
# in this repo, generalizing it to multiple augmentation llms

if not isinstance(augment_llm, list):
augment_llms = [augment_llm]
else:
augment_llms = augment_llm
augment_llms = cast_tuple(augment_llm)

if exists(augment_transformer_blocks):
if is_bearable(augment_transformer_blocks, List[Module]):
Expand Down Expand Up @@ -225,9 +234,12 @@ def __init__(
get_anchor_blocks_fn = x_transformer_blocks if isinstance(anchor_llm, TransformerWrapper) else get_anchor_transformer_blocks_fn
anchor_transformer_blocks = get_anchor_blocks_fn(self.anchor_llm)

for augment_llm in self.augment_llms:
get_augment_blocks_fn = x_transformer_blocks if isinstance(augment_llm, TransformerWrapper) else get_augment_transformer_blocks_fn
augment_transformer_blocks = [get_augment_blocks_fn(llm) for llm in self.augment_llms]
augment_transformer_blocks = []

for augment_llm, extract in zip(self.augment_llms, augment_extract_layers_fn):
extract = default(extract, x_transformer_blocks if isinstance(augment_llm, TransformerWrapper) else None)
assert exists(extract)
augment_transformer_blocks.append(extract(augment_llm))

# calculation for determining every Nth layer of augmentation layer hiddens is attended to
# in paper, they did every 4th layer of 1 augmentation llm
@@ -252,12 +264,14 @@

assert len(anchor_to_augment_blocks) == len(self.augment_llms)

forward_hook_get_hidden = cast_tuple(forward_hook_get_hidden, len(self.augment_llms))

# cross attend from anchor to augment llm using module forward hooks

all_anchor_dims = []
all_augment_dims = []

for (anchor_blocks_to_hook, augment_blocks_to_hook), augment_llm in zip(anchor_to_augment_blocks, self.augment_llms):
for (anchor_blocks_to_hook, augment_blocks_to_hook), augment_llm, position, maybe_one_input_shape in zip(anchor_to_augment_blocks, self.augment_llms, forward_hook_get_hidden, input_shape):

# number of cross attention for one augmentation llm

@@ -271,7 +285,7 @@ def __init__(
temp_hooks = []

def get_shape(shapes_arr, _, inp, out):
hiddens = out if forward_hook_get_hidden == 'output' else inp
hiddens = out if position == 'output' else inp
shapes_arr.append(hiddens.shape[-1])

get_anchor_dims = partial(get_shape, anchor_dims)
@@ -281,9 +295,15 @@ def get_shape(shapes_arr, _, inp, out):
temp_hooks.append(anchor_block.register_forward_hook(get_anchor_dims))
temp_hooks.append(augment_block.register_forward_hook(get_augment_dims))

dummy_input = torch.ones((1, 1), dtype = torch.long)
self.anchor_llm(dummy_input)
augment_llm(dummy_input)
default_dummy_input = torch.ones((1, 1), dtype = torch.long)

if exists(maybe_one_input_shape):
augment_dummy_input = torch.randn((1, *maybe_one_input_shape))
else:
augment_dummy_input = default_dummy_input

self.anchor_llm(default_dummy_input)
augment_llm(augment_dummy_input)

# unregister temporary hooks

@@ -295,15 +315,15 @@ def get_shape(shapes_arr, _, inp, out):

# instantiate cross attentions

for anchor_dims, augment_dims, (anchor_blocks_to_hook, augment_blocks_to_hook), augment_llm in zip(all_anchor_dims, all_augment_dims, anchor_to_augment_blocks, self.augment_llms):
for anchor_dims, augment_dims, (anchor_blocks_to_hook, augment_blocks_to_hook), augment_llm, position in zip(all_anchor_dims, all_augment_dims, anchor_to_augment_blocks, self.augment_llms, forward_hook_get_hidden):

recorders = []
one_augment_llm_cross_attns = ModuleList([])

for dim_anchor, dim_augment, _ in zip(anchor_dims, augment_dims, range(num_cross_attns)):
recorder = Recorder(forward_hook_get_hidden = forward_hook_get_hidden)
recorder = Recorder(forward_hook_get_hidden = position)
recorders.append(recorder)
one_augment_llm_cross_attns.append(CrossAttentionBlock(dim = dim_anchor, dim_context = dim_augment, recorder = recorder, forward_hook_get_hidden = forward_hook_get_hidden, **attn_kwargs))
one_augment_llm_cross_attns.append(CrossAttentionBlock(dim = dim_anchor, dim_context = dim_augment, recorder = recorder, forward_hook_get_hidden = anchor_forward_hook_get_hidden, **attn_kwargs))

# connect the two models

@@ -338,7 +358,7 @@ def forward(
seq: Tensor,
*,
prompt: Union[Tensor, SequenceOf(Tensor)],
prompt_mask: Optional[Union[Tensor, SequenceOf(Tensor)]] = None,
prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None,
mask: Optional[Tensor] = None,
return_loss = True,
anchor_llm_in_train_mode = True # unsure about this
@@ -357,20 +377,16 @@ def forward(

num_augment_llms = len(self.augment_llms)

if not is_bearable(prompt, Sequence):
prompts = (prompt,) * num_augment_llms
else:
prompts = prompt
prompts = cast_tuple(prompt, num_augment_llms)

assert len(prompts) == num_augment_llms

# prompt masks

if not exists(prompt_mask):
prompt_mask = tuple(p != self.pad_id for p in prompts)
prompt_mask = tuple((p != self.pad_id if not torch.is_floating_point(p) else None) for p in prompts)

if not is_bearable(prompt_mask, Sequence):
prompt_mask = (prompt_mask,) * num_augment_llms
prompt_mask = cast_tuple(prompt_mask, num_augment_llms)

prompt_masks = prompt_mask # at this point, should be plural

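A matching forward-pass sketch, again illustrative and under the same assumptions as the constructor sketch above: one prompt is passed per augment LLM, in the same order as they were given to the constructor, and for a floating-point (image) prompt the derived mask defaults to None, matching the prompt-mask handling in the diff. The call assumes the default `return_loss = True` path yields a scalar training loss.

```python
seq = torch.randint(0, 20000, (1, 1024))          # sequence fed to the anchor LLM

text_prompt  = torch.randint(0, 20000, (1, 256))  # tokens for the text augment LLM
image_prompt = torch.randn(1, 3, 256, 256)        # image for the vision augment model

# one prompt per augment llm, ordered as in the constructor
loss = calm(
    seq,
    prompt = (text_prompt, image_prompt)
)

loss.backward()
```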
4 changes: 3 additions & 1 deletion README.md
@@ -108,9 +108,11 @@ calm = CALM(
- [x] take care of finetuning training logic
- [x] extend to a list of augmentation llms
- [x] full connectivity customization
- [x] make simple vit work
- [ ] show example
- [ ] refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}}
- [ ] custom number of augmentation layers per augmentation llm
- [ ] move the hook logic for deriving hidden shapes to pytorch-custom-utils for reuse
- [ ] show an example of two augmentation llms with different prompts, one vision transformer, the other text-based

- [ ] handle a wrapper or function that takes in the sequence and prompt length, and auto derives the inputs to CALM
- [ ] show example of manually passing in list of transformer blocks as `List[Module]`. try out with some popular pretrained models
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'CALM-Pytorch',
packages = find_packages(exclude=[]),
version = '0.0.27',
version = '0.0.28',
license='MIT',
description = 'CALM - Pytorch',
author = 'Phil Wang',
