diff --git a/src/indico_patcher/types.py b/src/indico_patcher/types.py index 4691868..d7b42de 100644 --- a/src/indico_patcher/types.py +++ b/src/indico_patcher/types.py @@ -20,7 +20,7 @@ PatchWrapper: TypeAlias = ClassWrapper | EnumWrapper # noqa: UP040 # Annotations for extra attributes in patched classes -class PatchedClass: +class PatchedClass(type): __patches__: list[type] __unpatched__: dict[str, dict[str, Any]] diff --git a/src/indico_patcher/util.py b/src/indico_patcher/util.py index df057c6..d7dbd63 100644 --- a/src/indico_patcher/util.py +++ b/src/indico_patcher/util.py @@ -189,11 +189,11 @@ def _store_unpatched(orig_class: PatchedClass, member_name: str, category: str) :param orig_class: The class to store the reference in :param member_name: The name of the member to store the reference for """ + orig_members = get_members(orig_class) # None can be a valid value for the member, so we need to check if the member is in the class dict - if member_name in orig_class.__dict__: + if member_name in orig_members: # TODO: Log warning if member is already patched - member = orig_class.__dict__[member_name] - orig_class.__unpatched__[category][member_name] = member + orig_class.__unpatched__[category][member_name] = orig_members[member_name] def _inject_super_proxy(func: FunctionType, orig_class: PatchedClass) -> FunctionType: diff --git a/tests/test_util.py b/tests/test_util.py index 2870261..c61d77f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -294,6 +294,23 @@ def test_store_unpatched_member(member_name, category, Fool): assert Fool.__unpatched__[category][member_name] == Fool.__dict__[member_name] +@pytest.mark.parametrize(("member_name", "category"), [ + ("attr", "attributes"), + ("prop", "properties"), + ("hprop", "hybrid_properties"), + ("meth", "methods"), + ("cmeth", "classmethods"), + ("smeth", "staticmethods"), +]) +def test_store_unpatched_member_from_parent(member_name, category, Fool): + class Magician(Fool): + def spell(self): + pass + + _store_unpatched(Magician, member_name, category) + assert Fool.__unpatched__[category][member_name] == Fool.__dict__[member_name] + + # -- inject super proxy -------------------------------------------------------- def test_inject_super_proxy(Fool):