Skip to content

Commit

Permalink
Start fixing modeldiff tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
runame committed Dec 11, 2023
1 parent 73cf036 commit 41ca68d
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 16 deletions.
11 changes: 7 additions & 4 deletions tests/modeldiffs/wmt/compare.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,10 +49,6 @@ def sd_transform(sd):
k_str = ''.join(k)
if 'SelfAttention' in k_str:
new_key = list(k)
new_key = [
i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0'
for i in new_key
]
if 'SelfAttention_0' in k_str:
if new_key[-2] == 'Dense_0':
# qkv
Expand All @@ -77,6 +73,13 @@ def sd_transform(sd):
# out
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
tuple(
k.replace('SelfAttention', 'MultiHeadDotProductAttention')
for k in key): value
for key,
value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
out[new_key] = sd[k]
Expand Down
11 changes: 7 additions & 4 deletions tests/modeldiffs/wmt_attention_temp/compare.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,10 +49,6 @@ def sd_transform(sd):
k_str = ''.join(k)
if 'SelfAttention' in k_str:
new_key = list(k)
new_key = [
i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0'
for i in new_key
]
if 'SelfAttention_0' in k_str:
if new_key[-2] == 'Dense_0':
# qkv
Expand All @@ -77,6 +73,13 @@ def sd_transform(sd):
# out
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
tuple(
k.replace('SelfAttention', 'MultiHeadDotProductAttention')
for k in key): value
for key,
value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
out[new_key] = sd[k]
Expand Down
11 changes: 7 additions & 4 deletions tests/modeldiffs/wmt_glu_tanh/compare.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,10 +49,6 @@ def sd_transform(sd):
k_str = ''.join(k)
if 'SelfAttention' in k_str:
new_key = list(k)
new_key = [
i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0'
for i in new_key
]
if 'SelfAttention_0' in k_str:
if new_key[-2] == 'Dense_0':
# qkv
Expand All @@ -77,6 +73,13 @@ def sd_transform(sd):
# out
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
tuple(
k.replace('SelfAttention', 'MultiHeadDotProductAttention')
for k in key): value
for key,
value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
out[new_key] = sd[k]
Expand Down
11 changes: 7 additions & 4 deletions tests/modeldiffs/wmt_post_ln/compare.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,10 +49,6 @@ def sd_transform(sd):
k_str = ''.join(k)
if 'SelfAttention' in k_str:
new_key = list(k)
new_key = [
i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0'
for i in new_key
]
if 'SelfAttention_0' in k_str:
if new_key[-2] == 'Dense_0':
# qkv
Expand All @@ -77,6 +73,13 @@ def sd_transform(sd):
# out
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
tuple(
k.replace('SelfAttention', 'MultiHeadDotProductAttention')
for k in key): value
for key,
value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
out[new_key] = sd[k]
Expand Down

0 comments on commit 41ca68d

Please sign in to comment.