From 41ca68dd42ea4003b5b2fdaea071dfba9606b1df Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 11 Dec 2023 21:14:53 +0100 Subject: [PATCH] Start fixing modeldiff tests --- tests/modeldiffs/wmt/compare.py | 11 +++++++---- tests/modeldiffs/wmt_attention_temp/compare.py | 11 +++++++---- tests/modeldiffs/wmt_glu_tanh/compare.py | 11 +++++++---- tests/modeldiffs/wmt_post_ln/compare.py | 11 +++++++---- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 806022687..8d0ee8411 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -49,10 +49,6 @@ def sd_transform(sd): k_str = ''.join(k) if 'SelfAttention' in k_str: new_key = list(k) - new_key = [ - i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' - for i in new_key - ] if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv @@ -77,6 +73,13 @@ def sd_transform(sd): # out out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass + out = { + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key): value + for key, + value in out.items() + } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index fbe52ee05..b50abd3ca 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -49,10 +49,6 @@ def sd_transform(sd): k_str = ''.join(k) if 'SelfAttention' in k_str: new_key = list(k) - new_key = [ - i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' - for i in new_key - ] if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv @@ -77,6 +73,13 @@ def sd_transform(sd): # out out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass + out = { + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key): value + for key, + value in out.items() + } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index dfe5364b2..1322ad0a0 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -49,10 +49,6 @@ def sd_transform(sd): k_str = ''.join(k) if 'SelfAttention' in k_str: new_key = list(k) - new_key = [ - i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' - for i in new_key - ] if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv @@ -77,6 +73,13 @@ def sd_transform(sd): # out out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass + out = { + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key): value + for key, + value in out.items() + } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k] diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index a0dae6791..bfd701736 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -49,10 +49,6 @@ def sd_transform(sd): k_str = ''.join(k) if 'SelfAttention' in k_str: new_key = list(k) - new_key = [ - i if i != 'SelfAttention_1' else 'MultiHeadDotProductAttention_0' - for i in new_key - ] if 'SelfAttention_0' in k_str: if new_key[-2] == 'Dense_0': # qkv @@ -77,6 +73,13 @@ def sd_transform(sd): # out out[(*new_key[:-2], 'out', new_key[-1])] = sd[k] pass + out = { + tuple( + k.replace('SelfAttention', 'MultiHeadDotProductAttention') + for k in key): value + for key, + value in out.items() + } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) out[new_key] = sd[k]