
Commit

cleanup
lucidrains committed Jan 20, 2024
1 parent 2c95ea9 commit 781dcd9
Showing 2 changed files with 13 additions and 11 deletions.
22 changes: 12 additions & 10 deletions CALM_pytorch/CALM.py
@@ -40,6 +40,8 @@
 Sequence = Union[Tuple, List]
 
+HiddenPosition = Union[Literal['input'], Literal['output']]
+
 def SequenceOf(t):
     return Union[Tuple[t, ...], List[t]]
 
@@ -63,7 +65,10 @@ def cast_tuple(t, length = 1):
 def get_indices_of_src_from_tgt(src_arr, tgt_arr):
     return [src_arr.index(el) for el in tgt_arr]
 
-def get_block_output_from_hook_outputs(hidden_position, _, inp, out):
+def get_block_output_from_hook_outputs(
+    hidden_position: HiddenPosition,
+    _, inp, out
+):
     maybe_tensor = out if hidden_position == 'output' else inp
 
     if isinstance(maybe_tensor, tuple):
@@ -128,10 +133,7 @@ class Recorder:
     @beartype
     def __init__(
         self,
-        forward_hook_get_hidden: Union[
-            Literal['output'],
-            Literal['input']
-        ] = 'output'
+        forward_hook_get_hidden: HiddenPosition = 'output'
     ):
         self.output = None
         self.get_output_fn = partial(get_block_output_from_hook_outputs, forward_hook_get_hidden)
@@ -194,8 +196,8 @@ def set_mask(self, mask: Tensor):
     def unset_mask(self):
         self.context_mask = None
 
-    def forward(self, _, inp, out):
-        x = out if self.forward_hook_get_hidden == 'output' else inp
+    def forward(self, *hook_args):
+        x = get_block_output_from_hook_outputs(self.forward_hook_get_hidden, *hook_args)
 
         context = self.recorder.pop_saved()
         maybe_enable_grad = torch.enable_grad if self.training else nullcontext
@@ -216,7 +218,7 @@ def forward(self, _, inp, out):
 @dataclass
 class AugmentParams:
     model: Module
-    hidden_position: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output'
+    hidden_position: SingularOrMany(HiddenPosition) = 'output'
     transformer_blocks: Optional[List[Module]] = None
     extract_blocks_fn: Optional[Callable[[Module], List[Module]]] = None
     input_shape: Optional[Tuple[int, ...]] = None
@@ -238,7 +240,7 @@ def __init__(
         ),
         anchor_extract_blocks_fn: Callable[[Module], List[Module]] = None,
         anchor_transformer_blocks: Optional[List[Module]] = None,
-        anchor_hidden_position: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output',
+        anchor_hidden_position: SingularOrMany(HiddenPosition) = 'output',
         pad_id: int = -1
     ):
         super().__init__()
@@ -373,7 +375,7 @@ def __init__(
         # function for getting output or input dimension
         # depending on get_hidden_position
 
-        def get_hidden_dim(hook_output: Tuple[Module, Tensor, Tensor], position: Union[Literal['input'], Literal['output']]):
+        def get_hidden_dim(hook_output, position: HiddenPosition):
             maybe_tensor = get_block_output_from_hook_outputs(position, *hook_output)
             return maybe_tensor.shape[-1]
 
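For context (not part of the diff): after this cleanup, every forward-hook callback is routed through get_block_output_from_hook_outputs, which receives the standard PyTorch hook arguments (module, input, output) and keeps the side named by the new HiddenPosition alias. Below is a minimal, self-contained sketch of that flow; the toy block, the hook registration, and any handling beyond the lines visible in the hunks are assumptions for illustration only, not code from the commit.

# illustrative sketch only, not part of the commit
import torch
from torch import nn
from typing import Literal, Union

HiddenPosition = Union[Literal['input'], Literal['output']]

def get_block_output_from_hook_outputs(
    hidden_position: HiddenPosition,
    _, inp, out
):
    # forward hooks receive (module, input, output); keep the requested side
    maybe_tensor = out if hidden_position == 'output' else inp

    # hook inputs/outputs may be tuples; assume hidden states come first
    if isinstance(maybe_tensor, tuple):
        maybe_tensor = maybe_tensor[0]

    assert torch.is_tensor(maybe_tensor)
    return maybe_tensor

block = nn.Linear(4, 4)
saved = []

# register_forward_hook invokes the hook as hook(module, input, output),
# matching the *hook_args forwarding used in the refactored forward method
block.register_forward_hook(
    lambda *hook_args: saved.append(
        get_block_output_from_hook_outputs('output', *hook_args)
    )
)

block(torch.randn(2, 4))
print(saved[0].shape)  # torch.Size([2, 4])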
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CALM-Pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.5',
+  version = '0.1.6',
   license='MIT',
   description = 'CALM - Pytorch',
   author = 'Phil Wang',
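Usage note: assuming the release is published to PyPI under the package name above (not something the diff itself confirms), the bumped version would be installable with pip install CALM-Pytorch==0.1.6.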
