diff --git a/CALM_pytorch/CALM.py b/CALM_pytorch/CALM.py
index e4a86f8..66586ba 100644
--- a/CALM_pytorch/CALM.py
+++ b/CALM_pytorch/CALM.py
@@ -41,14 +41,23 @@ def SequenceOf(t):
     return Union[Tuple[t, ...], List[t]]
 
 
+def SingularOrMany(t):
+    return Union[t, SequenceOf(t)]
+
 # helpers
 
 def exists(v):
     return v is not None
 
+def default(v, d):
+    return v if exists(v) else d
+
 def xnor(x, y):
     return not (x ^ y)
 
+def cast_tuple(t, length = 1):
+    return t if is_bearable(t, Sequence) else ((t,) * length)
+
 # freezing llms
 
 @beartype
@@ -165,32 +174,32 @@ class CALM(Module):
     def __init__(
         self,
         anchor_llm: Module,
-        augment_llm: Union[Module, List[Module]],
+        augment_llm: Union[Module, SequenceOf(Module)],
         *,
         attn_kwargs: dict = dict(
             linear_project_context = True,
             pre_rmsnorm = True,
             flash = True
         ),
-        forward_mask_to_augment_llm_key: Optional[str] = None, # if set, will forward the prompt_mask to the augment LLM (in case it is an encoder) with this key
-        augment_every_num_layers = 4, # in the paper, they do 4
-        get_augment_transformer_blocks_fn: Callable[[Module], List[Module]] = lambda module: module.blocks,
-        get_anchor_transformer_blocks_fn: Callable[[Module], List[Module]] = lambda module: module.blocks,
+        input_shape: SingularOrMany(Optional[Tuple]) = None,
+        forward_mask_to_augment_llm_key: SingularOrMany(Optional[str]) = None, # if set, will forward the prompt_mask to the augment LLM (in case it is an encoder) with this key
+        augment_every_num_layers: int = 4, # in the paper, they do 4
+        augment_extract_layers_fn: SingularOrMany(Optional[Callable[[Module], List[Module]]]) = None,
+        augment_llm_mask_kwarg: SingularOrMany(Optional[str]) = None,
+        anchor_extract_layers_fn: Callable[[Module], List[Module]] = None,
         augment_transformer_blocks: Optional[Union[List[List[Module]], List[Module]]] = None,
         anchor_transformer_blocks: Optional[List[Module]] = None,
         anchor_to_augment_blocks: Optional[List[Tuple[List[Module], List[Module]]]] = None,
-        forward_hook_get_hidden: Union[Literal['input'], Literal['output']] = 'output',
-        pad_id = -1
+        forward_hook_get_hidden: SingularOrMany(Union[Literal['input'], Literal['output']]) = 'output',
+        anchor_forward_hook_get_hidden: Union[Literal['input'], Literal['output']] = 'output',
+        pad_id: int = -1
     ):
         super().__init__()
 
         # account for single augmentation llm (which is what the paper did)
         # in this repo, generalizing it to multiple augmentation llms
 
-        if not isinstance(augment_llm, list):
-            augment_llms = [augment_llm]
-        else:
-            augment_llms = augment_llm
+        augment_llms = cast_tuple(augment_llm)
 
         if exists(augment_transformer_blocks):
             if is_bearable(augment_transformer_blocks, List[Module]):
@@ -225,9 +234,12 @@ def __init__(
         get_anchor_blocks_fn = x_transformer_blocks if isinstance(anchor_llm, TransformerWrapper) else get_anchor_transformer_blocks_fn
         anchor_transformer_blocks = get_anchor_blocks_fn(self.anchor_llm)
 
-        for augment_llm in self.augment_llms:
-            get_augment_blocks_fn = x_transformer_blocks if isinstance(augment_llm, TransformerWrapper) else get_augment_transformer_blocks_fn
-            augment_transformer_blocks = [get_augment_blocks_fn(llm) for llm in self.augment_llms]
+        augment_transformer_blocks = []
+
+        for augment_llm, extract in zip(self.augment_llms, augment_extract_layers_fn):
+            extract = default(extract, x_transformer_blocks if isinstance(augment_llm, TransformerWrapper) else None)
+            assert exists(extract)
+            augment_transformer_blocks.append(extract(augment_llm))
 
         # calculation for determining every Nth layer of augmentation layer hiddens is attended to
         # in paper, they did every 4th layer of 1 augmentation llm
@@ -252,12 +264,14 @@ def __init__(
 
         assert len(anchor_to_augment_blocks) == len(self.augment_llms)
 
+        forward_hook_get_hidden = cast_tuple(forward_hook_get_hidden, len(self.augment_llms))
+
         # cross attend from anchor to augment llm using module forward hooks
 
         all_anchor_dims = []
         all_augment_dims = []
 
-        for (anchor_blocks_to_hook, augment_blocks_to_hook), augment_llm in zip(anchor_to_augment_blocks, self.augment_llms):
+        for (anchor_blocks_to_hook, augment_blocks_to_hook), augment_llm, position, maybe_one_input_shape in zip(anchor_to_augment_blocks, self.augment_llms, forward_hook_get_hidden, input_shape):
 
             # number of cross attention for one augmentation llm
 
@@ -271,7 +285,7 @@ def __init__(
             temp_hooks = []
 
             def get_shape(shapes_arr, _, inp, out):
-                hiddens = out if forward_hook_get_hidden == 'output' else inp
+                hiddens = out if position == 'output' else inp
                 shapes_arr.append(hiddens.shape[-1])
 
             get_anchor_dims = partial(get_shape, anchor_dims)
@@ -281,9 +295,15 @@ def get_shape(shapes_arr, _, inp, out):
                 temp_hooks.append(anchor_block.register_forward_hook(get_anchor_dims))
                 temp_hooks.append(augment_block.register_forward_hook(get_augment_dims))
 
-            dummy_input = torch.ones((1, 1), dtype = torch.long)
-            self.anchor_llm(dummy_input)
-            augment_llm(dummy_input)
+            default_dummy_input = torch.ones((1, 1), dtype = torch.long)
+
+            if exists(maybe_one_input_shape):
+                augment_dummy_input = torch.randn((1, *maybe_one_input_shape))
+            else:
+                augment_dummy_input = default_dummy_input
+
+            self.anchor_llm(default_dummy_input)
+            augment_llm(augment_dummy_input)
 
             # unregister temporary hooks
 
@@ -295,15 +315,15 @@ def get_shape(shapes_arr, _, inp, out):
 
         # instantiate cross attentions
 
-        for anchor_dims, augment_dims, (anchor_blocks_to_hook, augment_blocks_to_hook), augment_llm in zip(all_anchor_dims, all_augment_dims, anchor_to_augment_blocks, self.augment_llms):
+        for anchor_dims, augment_dims, (anchor_blocks_to_hook, augment_blocks_to_hook), augment_llm, position in zip(all_anchor_dims, all_augment_dims, anchor_to_augment_blocks, self.augment_llms, forward_hook_get_hidden):
 
             recorders = []
             one_augment_llm_cross_attns = ModuleList([])
 
             for dim_anchor, dim_augment, _ in zip(anchor_dims, augment_dims, range(num_cross_attns)):
-                recorder = Recorder(forward_hook_get_hidden = forward_hook_get_hidden)
+                recorder = Recorder(forward_hook_get_hidden = position)
                 recorders.append(recorder)
-                one_augment_llm_cross_attns.append(CrossAttentionBlock(dim = dim_anchor, dim_context = dim_augment, recorder = recorder, forward_hook_get_hidden = forward_hook_get_hidden, **attn_kwargs))
+                one_augment_llm_cross_attns.append(CrossAttentionBlock(dim = dim_anchor, dim_context = dim_augment, recorder = recorder, forward_hook_get_hidden = anchor_forward_hook_get_hidden, **attn_kwargs))
 
             # connect the two models
 
@@ -338,7 +358,7 @@ def forward(
         seq: Tensor,
         *,
         prompt: Union[Tensor, SequenceOf(Tensor)],
-        prompt_mask: Optional[Union[Tensor, SequenceOf(Tensor)]] = None,
+        prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None,
         mask: Optional[Tensor] = None,
         return_loss = True,
         anchor_llm_in_train_mode = True # unsure about this
@@ -357,20 +377,16 @@ def forward(
 
         num_augment_llms = len(self.augment_llms)
 
-        if not is_bearable(prompt, Sequence):
-            prompts = (prompt,) * num_augment_llms
-        else:
-            prompts = prompt
+        prompts = cast_tuple(prompt, num_augment_llms)
 
         assert len(prompts) == num_augment_llms
 
         # prompt masks
 
         if not exists(prompt_mask):
-            prompt_mask = tuple(p != self.pad_id for p in prompts)
+            prompt_mask = tuple((p != self.pad_id if not torch.is_floating_point(p) else None) for p in prompts)
 
-        if not is_bearable(prompt_mask, Sequence):
-            prompt_mask = (prompt_mask,) * num_augment_llms
+        prompt_mask = cast_tuple(prompt_mask, num_augment_llms)
 
         prompt_masks = prompt_mask # at this point, should be plural
 
diff --git a/README.md b/README.md
index cca1d25..e1a0d20 100644
--- a/README.md
+++ b/README.md
@@ -108,9 +108,11 @@ calm = CALM(
 - [x] take care of finetuning training logic
 - [x] extend to a list of augmentation llms
 - [x] full connectivity customization
+- [x] make simple vit work
+- [ ] show example
+- [ ] refactor so extraction fn, mask kwarg, and other related hparams are grouped together under a dictionary of {[augment_llm_name]: {augment_llm_related_hparams}}
 - [ ] custom number of augmentation layers per augmetation llm
 - [ ] move the hook logic for deriving hidden shapes to pytorch-custom-utils for reuse
-- [ ] show an example of two augmentation llms with different prompts, one vision transformer, the other text-based
 - [ ] handle a wrapper or function that takes in the sequence and prompt length, and auto derives the inputs to CALM
 - [ ] show example of manually passing in list of transformer blocks as `List[Module]`. try out with some popular pretrained models
 
diff --git a/setup.py b/setup.py
index 8cf9571..752479c 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CALM-Pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.27',
+  version = '0.0.28',
   license='MIT',
   description = 'CALM - Pytorch',
   author = 'Phil Wang',
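
Reviewer note: the new `SingularOrMany` annotation and `cast_tuple` helper carry most of this change. Every per-augment-LLM hyperparameter (`input_shape`, `forward_mask_to_augment_llm_key`, `augment_extract_layers_fn`, `forward_hook_get_hidden`, and so on) can now be given either as a single value or as one value per augmentation LLM. A minimal sketch of their behavior, reusing the helper definitions verbatim from this patch; the only assumption is `beartype`, which the repo already depends on:

```python
from typing import List, Tuple, Union

from beartype.door import is_bearable
from beartype.typing import Sequence

# helper definitions copied from this patch

def SequenceOf(t):
    return Union[Tuple[t, ...], List[t]]

def SingularOrMany(t):
    return Union[t, SequenceOf(t)]

def cast_tuple(t, length = 1):
    return t if is_bearable(t, Sequence) else ((t,) * length)

# a singular value is repeated out to the requested length

assert cast_tuple(4, 3) == (4, 4, 4)

# an existing sequence passes through untouched, length unchecked

assert cast_tuple((4, 5), 3) == (4, 5)

# one annotation accepts both the singular and the plural spelling

assert is_bearable(4, SingularOrMany(int))
assert is_bearable([4, 5], SingularOrMany(int))
assert not is_bearable('nope', SingularOrMany(int))
```

One subtlety worth flagging: a plain `str` is itself a `Sequence`, so a bare string such as `forward_hook_get_hidden = 'output'` passes through `cast_tuple` without being repeated per augment LLM. Wrapping such values in an explicit one-entry tuple, as the sketch below does, sidesteps the ambiguity.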
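
The new `input_shape` kwarg is what the `make simple vit work` todo item rests on: when a shape is supplied for an augment model, the dimension-probing dummy forward feeds it `torch.randn((1, *input_shape))` rather than a `(1, 1)` tensor of token ids. Below is a hypothetical end-to-end sketch. `ToyViT`, its dimensions, and the extraction lambda are placeholders invented for illustration, and whether it runs as-is depends on parts of `CALM.py` outside these hunks; only the `CALM` keyword arguments themselves come from this patch:

```python
import torch
from torch import nn

from x_transformers import TransformerWrapper, Decoder

from CALM_pytorch import CALM

# ToyViT is a placeholder vision transformer: a conv patch embedding followed
# by encoder blocks whose hidden states CALM's forward hooks can record

class ToyViT(nn.Module):
    def __init__(self, dim = 256, patch_size = 32):
        super().__init__()
        self.to_patch_embed = nn.Conv2d(3, dim, patch_size, stride = patch_size)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model = dim, nhead = 8, batch_first = True)
            for _ in range(6)
        ])

    def forward(self, images):
        x = self.to_patch_embed(images)    # (batch, dim, h', w')
        x = x.flatten(2).transpose(1, 2)   # (batch, num_patches, dim)

        for block in self.blocks:
            x = block(x)

        return x

anchor_llm = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 12, heads = 8)
)

vit = ToyViT()

calm = CALM(
    anchor_llm = anchor_llm,
    augment_llm = (vit,),
    input_shape = ((3, 256, 256),),                          # dummy forward becomes torch.randn((1, 3, 256, 256))
    augment_extract_layers_fn = (lambda m: list(m.blocks),), # which submodules to hook for hidden states
    forward_hook_get_hidden = ('output',)                    # record block outputs, as an explicit one-entry tuple
)

# the image tensor plays the role of the prompt for the vision augment llm;
# being floating point, it gets no pad-derived mask under the new default

seq = torch.randint(0, 20000, (1, 512))
images = torch.randn(1, 3, 256, 256)

loss = calm(seq, prompt = (images,))
loss.backward()
```

Passing the per-augment-LLM values as one-entry tuples keeps the singular-versus-sequence casting unambiguous, and the floating-point image prompt is exactly the case the new `torch.is_floating_point` check in the default mask logic exists for.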