diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 806022687..8d0ee8411 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -49,10 +49,6 @@ def sd_transform(sd): k_str = ''.join(k) if 'SelfAttention' in k_str: new_key = list(k) - new_key = [ - i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' - for i in new_key - ] if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv @@ -77,6 +73,13 @@ def sd_transform(sd): # out out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass + out = { + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key): value + for key, + value in out.items() + } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index fbe52ee05..b50abd3ca 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -49,10 +49,6 @@ def sd_transform(sd): k_str = ''.join(k) if 'SelfAttention' in k_str: new_key = list(k) - new_key = [ - i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' - for i in new_key - ] if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv @@ -77,6 +73,13 @@ def sd_transform(sd): # out out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass + out = { + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key): value + for key, + value in out.items() + } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index dfe5364b2..1322ad0a0 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -49,10 +49,6 @@ def sd_transform(sd): k_str = ''.join(k) if 'SelfAttention' in k_str: new_key = list(k) - new_key = [ - i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' - for i in new_key - ] if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv @@ -77,6 +73,13 @@ def sd_transform(sd): # out out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass + out = { + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key): value + for key, + value in out.items() + } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index a0dae6791..bfd701736 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -49,10 +49,6 @@ def sd_transform(sd): k_str = ''.join(k) if 'SelfAttention' in k_str: new_key = list(k) - new_key = [ - i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' - for i in new_key - ] if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv @@ -77,6 +73,13 @@ def sd_transform(sd): # out out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass + out = { + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key): value + for key, + value in out.items() + } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k]