
Commit

cleanup types
lucidrains committed Sep 12, 2024
1 parent 9182050 commit 92f8521
Showing 2 changed files with 29 additions and 29 deletions.
56 changes: 28 additions & 28 deletions CALM_pytorch/CALM.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from math import ceil
 from pathlib import Path
 from functools import partial
@@ -14,7 +16,7 @@
 
 from beartype import beartype
 from beartype.door import is_bearable
-from beartype.typing import List, Optional, Callable, Type, Tuple, Union, Literal
+from beartype.typing import Callable, Type, Literal
 
 from einops import rearrange, repeat
 
@@ -43,15 +45,15 @@
 
 # types
 
-Sequence = Union[Tuple, List]
+Sequence = tuple | list
 
-HiddenPosition = Union[Literal['input'], Literal['output']]
+HiddenPosition = Literal['input', 'output']
 
 def SequenceOf(t):
-    return Union[Tuple[t, ...], List[t]]
+    return tuple[t, ...] | list[t]
 
 def SingularOrMany(t):
-    return Union[t, SequenceOf(t)]
+    return t | SequenceOf(t)
 
 # helpers
 
@@ -96,7 +98,7 @@ def freeze_all_layers_(module):
 # ex. for x-transformers TransformerWrapper
 
 @beartype
-def x_transformer_blocks(transformer: TransformerWrapper) -> List[Module]:
+def x_transformer_blocks(transformer: TransformerWrapper) -> list[Module]:
     blocks = []
     for layer in transformer.attn_layers.layers:
         blocks.append(layer[-1])
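
The hunks above carry the substance of the cleanup: the type helpers now build their hints from PEP 585 builtin generics and PEP 604 unions instead of beartype.typing's Union/Tuple/List. Below is a minimal sketch of how the rewritten helpers behave under beartype; it is not part of the commit and assumes Python 3.10+, since these unions are built when the helpers are called rather than deferred by from __future__ import annotations.

from __future__ import annotations

from typing import Literal

from beartype import beartype
from beartype.door import is_bearable

HiddenPosition = Literal['input', 'output']

def SequenceOf(t):
    # tuple-or-list of t, replacing Union[Tuple[t, ...], List[t]]
    return tuple[t, ...] | list[t]

def SingularOrMany(t):
    # a bare t, or any sequence of t
    return t | SequenceOf(t)

# runtime checks accept the same values as with the old spellings
assert is_bearable('output', SingularOrMany(HiddenPosition))
assert is_bearable(('input', 'output'), SingularOrMany(HiddenPosition))
assert not is_bearable('middle', SingularOrMany(HiddenPosition))

# beartype resolves the stringified annotation lazily, the same pattern CALM.py relies on
@beartype
def set_hidden_position(position: SingularOrMany(HiddenPosition) = 'output'):
    return position
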
@@ -108,9 +110,9 @@ class Recorder:
     @beartype
     def __init__(
         self,
-        outputs: Optional[List] = None,
+        outputs: list | None = None,
         forward_hook_get_hidden: HiddenPosition = 'output',
-        modules: Optional[List] = None,
+        modules: list | None = None,
     ):
         self.output = default(outputs, [])
         self.modules = modules
@@ -129,7 +131,7 @@ class ExtractHiddensWrapper(Module):
     def __init__(
         self,
         model: Module,
-        blocks: List[Module],
+        blocks: list[Module],
         hidden_positions: SingularOrMany(HiddenPosition) = 'output'
     ):
         super().__init__()
@@ -171,10 +173,7 @@ def __init__(
         dim_context,
         linear_project_context = True, # in the paper, they do a projection on the augmented hidden states. not sure if this is needed though, but better to be accurate first
         pre_rmsnorm = False,
-        forward_hook_get_hidden: Union[
-            Literal['output'],
-            Literal['input']
-        ] = 'output',
+        forward_hook_get_hidden: HiddenPosition = 'output',
         **kwargs
     ):
         super().__init__()
@@ -232,13 +231,13 @@ def forward(self, *hook_args):
 class AugmentParams:
     model: Module
     hidden_position: SingularOrMany(HiddenPosition) = 'output'
-    transformer_blocks: Optional[List[Module]] = None
-    extract_blocks_fn: Optional[Callable[[Module], List[Module]]] = None
+    transformer_blocks: list[Module] | None = None
+    extract_blocks_fn: Callable[[Module], list[Module]] | None = None
     model_return_hiddens: bool = False
-    input_shape: Optional[Tuple[int, ...]] = None
-    connections: Optional[Tuple[Tuple[int, int], ...]] = None
+    input_shape: tuple[int, ...] | None = None
+    connections: tuple[tuple[int, int], ...] | None = None
     connect_every_num_layers: int = 4 # in the paper, they do 4
-    mask_kwarg: Optional[str] = None
+    mask_kwarg: str | None = None
 
 class CALM(Module):
     @beartype
@@ -252,8 +251,8 @@ def __init__(
             pre_rmsnorm = True,
             flash = True
         ),
-        anchor_extract_blocks_fn: Callable[[Module], List[Module]] = None,
-        anchor_transformer_blocks: Optional[List[Module]] = None,
+        anchor_extract_blocks_fn: Callable[[Module], list[Module]] = None,
+        anchor_transformer_blocks: list[Module] | None = None,
         anchor_hidden_position: SingularOrMany(HiddenPosition) = 'output',
         pad_id: int = -1
     ):
@@ -265,12 +264,12 @@ def __init__(
         augment_llms_params = augment_llms
 
         self.anchor_llm = anchor_llm
-        self.augment_llms = nn.ModuleList([])
+        self.augment_llms = ModuleList([])
 
         # the only parameters being learned are a bunch of cross attention layers
         # attending from anchor to augmentation model(s)
 
-        self.cross_attns = nn.ModuleList([])
+        self.cross_attns = ModuleList([])
 
         # determine the transformer blocks involved
         # derive the blocks from the model and extraction function, if not
@@ -280,6 +279,7 @@ def __init__(
             anchor_transformer_blocks = get_anchor_blocks_fn(self.anchor_llm)
 
         anchor_hidden_position = cast_tuple(anchor_hidden_position, len(anchor_transformer_blocks))
+
         assert len(anchor_transformer_blocks) == len(anchor_hidden_position)
 
         # wrap each augment llm with a wrapper that extracts the hiddens
@@ -447,7 +447,7 @@ def release_cross_attn_contexts(self):
     def forward_augments(
         self,
         prompt: Tensor,
-        prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None
+        prompt_mask: SingularOrMany(SequenceOf(Tensor)) | None = None
     ):
         # if only one prompt is given with multiple augmentation llms, then just feed that one prompt into all augment llm
 
@@ -518,7 +518,7 @@ def generate(
         self,
         prompt: Tensor,
         seq_len: int,
-        prompt_mask: Optional[SingularOrMany(SequenceOf(Tensor))] = None,
+        prompt_mask: SingularOrMany(SequenceOf(Tensor)) | None = None,
         filter_fn: Callable = top_p,
         filter_kwargs: dict = dict(
             thres = 0.9
@@ -554,8 +554,8 @@ def forward(
         seq: Tensor,
         *,
         prompt: SingularOrMany(Tensor),
-        prompt_mask: Optional[SingularOrMany(Tensor)] = None,
-        mask: Optional[Tensor] = None,
+        prompt_mask: SingularOrMany(Tensor) | None = None,
+        mask: Tensor | None = None,
         return_loss = True,
         anchor_llm_in_train_mode = True # unsure about this
     ):
@@ -622,11 +622,11 @@ def __init__(
         weight_decay: float,
         batch_size: int,
         dataset: Dataset,
-        data_kwarg_names: Tuple[str, ...] = ('seq', 'mask', 'prompt'),
+        data_kwarg_names: tuple[str, ...] = ('seq', 'mask', 'prompt'),
         accelerate_kwargs: dict = dict(),
         checkpoint_every: int = 1000,
         checkpoint_path: str = './checkpoints',
-        scheduler: Optional[Type[_LRScheduler]] = None,
+        scheduler: Type[_LRScheduler] | None = None,
         scheduler_kwargs: dict = dict(),
         warmup_steps: int = 1000,
         max_grad_norm = 0.5,
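
The remaining CALM.py hunks apply the same substitution throughout: Optional[X] becomes X | None, the typing-module containers become builtin generics, and nn.ModuleList is referenced by the bare ModuleList name. Because of the from __future__ import annotations added in the first hunk, these annotations are stored as plain strings and only resolved when beartype inspects them, so nothing is evaluated at class-definition time. A small illustrative sketch, not taken from the repository, mirroring a few AugmentParams fields:

from __future__ import annotations

from dataclasses import dataclass

from torch.nn import Module

@dataclass
class AugmentParamsSketch:
    # hypothetical stand-in for AugmentParams, trimmed to a few fields
    model: Module
    transformer_blocks: list[Module] | None = None
    input_shape: tuple[int, ...] | None = None
    mask_kwarg: str | None = None

# with postponed evaluation the hint is kept verbatim as a string until resolved
print(AugmentParamsSketch.__annotations__['transformer_blocks'])  # list[Module] | None
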
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CALM-Pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'CALM - Pytorch',
   author = 'Phil Wang',
