protect against unordered list of modules passed in for anchor
lucidrains committed Feb 1, 2024
1 parent eef4bdb commit 7de578e
Showing 2 changed files with 31 additions and 7 deletions.
36 changes: 30 additions & 6 deletions CALM_pytorch/CALM.py
@@ -109,12 +109,18 @@ class Recorder:
     def __init__(
         self,
         outputs: Optional[List] = None,
-        forward_hook_get_hidden: HiddenPosition = 'output'
+        forward_hook_get_hidden: HiddenPosition = 'output',
+        modules: Optional[List] = None,
     ):
         self.output = default(outputs, [])
+        self.modules = modules
         self.get_output_fn = partial(get_block_output_from_hook_outputs, forward_hook_get_hidden)
 
     def __call__(self, *args):
+
+        if exists(self.modules):
+            self.modules.append(args[0])
+
         hidden = self.get_output_fn(*args)
         self.output.append(hidden.detach())
 
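For context on the Recorder change above: PyTorch invokes a forward hook as hook(module, inputs, output), so the first positional argument Recorder.__call__ receives is the block that just executed, and that is what gets appended to modules. A minimal sketch of that callback shape, with illustrative names (hook, seen) that are not from the library:

import torch
from torch import nn

seen = []

def hook(module, inputs, output):
    # the hook's first argument is the module that just ran
    seen.append(type(module).__name__)

layer = nn.Linear(4, 4)
layer.register_forward_hook(hook)
layer(torch.randn(1, 4))

print(seen)  # ['Linear']
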
@@ -133,19 +139,27 @@ def __init__(
         self.model = model
 
         self.outputs = []
+        self.modules = []
         self.recorders = []
 
         for block, hidden_position in zip(blocks, hidden_positions):
-            recorder = Recorder(self.outputs, hidden_position)
+            recorder = Recorder(self.outputs, hidden_position, self.modules)
             self.recorders.append(recorder)
             block.register_forward_hook(recorder)
 
-    def forward(self, *args, **kwargs):
+    def forward(self, *args, return_hooked_modules = False, **kwargs):
         self.model(*args, **kwargs)
 
         outputs = self.outputs.copy()
+        modules = self.modules.copy()
+
         self.outputs.clear()
-        return outputs
+        self.modules.clear()
+
+        if not return_hooked_modules:
+            return outputs
+
+        return outputs, modules
 
 # cross attention wrapper class
 
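The forward change above establishes a copy-then-clear contract: the recorded hiddens (and now the hooked modules) are copied out on every call, the internal buffers are emptied so nothing leaks into the next forward, and the modules are only returned when explicitly requested. A condensed, self-contained sketch of the same pattern, assuming a hypothetical HiddenRecorder class standing in for ExtractHiddensWrapper:

import torch
from torch import nn

class HiddenRecorder(nn.Module):
    # hypothetical, simplified analogue of the wrapper's recording behaviour
    def __init__(self, model, blocks):
        super().__init__()
        self.model = model
        self.saved_outputs = []
        self.saved_modules = []
        for block in blocks:
            block.register_forward_hook(self._record)

    def _record(self, module, inputs, output):
        # hooks fire as each block runs, so both lists end up in execution order
        self.saved_modules.append(module)
        self.saved_outputs.append(output.detach())

    def forward(self, *args, return_hooked_modules = False, **kwargs):
        self.model(*args, **kwargs)

        outputs, modules = self.saved_outputs.copy(), self.saved_modules.copy()

        # clear per call so state never accumulates across forwards
        self.saved_outputs.clear()
        self.saved_modules.clear()

        if not return_hooked_modules:
            return outputs

        return outputs, modules

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
recorder = HiddenRecorder(net, blocks = list(net))

hiddens, mods = recorder(torch.randn(2, 8), return_hooked_modules = True)
assert len(hiddens) == 3 and mods == list(net)
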
@@ -271,12 +285,24 @@ def __init__(
         # wrap each augment llm with a wrapper that extracts the hiddens
         # if the augment llm is already modified to return a List[Tensor], set model_return_hiddens = True
 
+        default_transformer_input = torch.ones((1, 1), dtype = torch.long)
+
         wrapped_anchor_llm = ExtractHiddensWrapper(
             anchor_llm,
             anchor_transformer_blocks,
             anchor_hidden_position
         )
 
+        # order the anchor transformer blocks by their execution order
+        # there is not a guarantee that the function or list of modules provided is in the right order
+        # just remove that gotcha
+
+        _, anchor_transformer_blocks = wrapped_anchor_llm(default_transformer_input, return_hooked_modules = True)
+
+        assert len(anchor_transformer_blocks) > 0
+
+        # process each augment llm and wrap them if necessary
+
         for params in augment_llms_params:
 
             if params.model_return_hiddens:
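
The added block above is the heart of the commit: instead of trusting the order in which anchor_transformer_blocks were handed in, a 1x1 token probe is pushed through the wrapped anchor and the blocks are taken back in the order their hooks fired, i.e. true execution order. A standalone sketch of that sanitization trick, using a toy model and illustrative names (fired, probe):

import random
import torch
from torch import nn

net = nn.Sequential(nn.Embedding(256, 32), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 256))

blocks = list(net)
random.shuffle(blocks)              # caller supplies the blocks in an arbitrary order

fired = []
handles = [block.register_forward_hook(lambda m, i, o: fired.append(m)) for block in blocks]

probe = torch.ones((1, 1), dtype = torch.long)   # same kind of 1x1 token probe the commit feeds through
net(probe)

assert fired == list(net)           # recovered order follows execution, not the shuffled input list

for handle in handles:
    handle.remove()
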
@@ -314,8 +340,6 @@ def __init__(
         # for sanitizing the input (making sure transformer blocks are ordered by execution)
         # and for magically determining hidden dimensions for cross attention
 
-        default_transformer_input = torch.ones((1, 1), dtype = torch.long)
-
         recording_inputs = [default_transformer_input]
 
         for params in augment_llms_params:
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CALM-Pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'CALM - Pytorch',
   author = 'Phil Wang',
