Skip to content

Commit

Permalink
more intuitive to have connections as (<augment layer>, <anchor layer>)
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucidrains committed Jan 31, 2024
1 parent 8b5e816 commit eef4bdb
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions CALM_pytorch/CALM.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def __init__(
anchor_layer_indices = [*range(1, len(anchor_outputs) + 1, anchor_every_num_layers)]
augment_layer_indices = [*range(1, len(augment_outputs) + 1, params.connect_every_num_layers)]

- params.connections = tuple(zip(anchor_layer_indices, augment_layer_indices))
+ params.connections = tuple(zip(augment_layer_indices, anchor_layer_indices))

self.connections = [params.connections for params in augment_llms_params]

Expand All @@ -369,7 +369,7 @@ def __init__(
for connection, params, augment_outputs in zip(self.connections, augment_llms_params, augments_outputs):
one_num_augment_blocks = len(augment_outputs)

- anchor_layer_indices, augment_layer_indices = tuple(zip(*connection))
+ augment_layer_indices, anchor_layer_indices = tuple(zip(*connection))

assert all([1 <= i <= len(anchor_outputs) for i in anchor_layer_indices]), 'you specified anchor llm layers outside of actual number of layers'
assert all([1 <= i <= len(augment_outputs) for i in augment_layer_indices]), 'you specified augment llm layers outside of actual number of layers'
Expand Down Expand Up @@ -466,7 +466,7 @@ def forward_augments(

for one_augment_hiddens, one_augment_cross_attns, one_augment_connections in zip(augments_hiddens, self.cross_attns, self.connections):

- for (_, augment_layer_index), cross_attn in zip(one_augment_connections, one_augment_cross_attns):
+ for (augment_layer_index, _), cross_attn in zip(one_augment_connections, one_augment_cross_attns):

cross_attn.context = one_augment_hiddens[augment_layer_index - 1]

Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,17 @@ calm = CALM(
AugmentParams(
model = augment_llm1,
connections = (
- (12, 1), # 12th layer of anchor attends to 1st layer of augment llm1
- (12, 2),
- (12, 3),
- (12, 4),
+ (1, 12), # 1st layer of augment llm1 attended to by 12th layer of anchor llm
+ (2, 12),
+ (3, 12),
+ (4, 12),
),
),
AugmentParams(
model = augment_llm2,
connections = (
- (1, 6), # 1st layer of anchor attends to 6th layer of augment llm2
- (2, 6),
+ (6, 1), # 6th layer of augment llm2 attended to by 1st layer of anchor llm
+ (6, 2),
(12, 12),
)
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'CALM-Pytorch',
packages = find_packages(exclude=[]),
- version = '0.1.11',
+ version = '0.2.0',
license='MIT',
description = 'CALM - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit eef4bdb

Please sign in to comment.