From 9be49b3b15cd26ce712ff286719dc7af61fa1ad5 Mon Sep 17 00:00:00 2001 From: Stanislav Terliakov <50529348+sterliakov@users.noreply.github.com> Date: Mon, 13 Jan 2025 23:08:41 +0100 Subject: [PATCH] Prevent crashing when `match` arms use name of existing callable (#18449) Fixes #16793. Fixes crash in #13666. Previously mypy considered that variables in match/case patterns must be Var's, causing a hard crash when a name of captured pattern clashes with a name of some existing function. This PR removes such assumption about Var and allows other nodes. --- mypy/checker.py | 19 +++++++---- test-data/unit/check-python310.test | 51 +++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index f6193a1273eb..79d178f3c644 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5402,17 +5402,21 @@ def _get_recursive_sub_patterns_map( return sub_patterns_map - def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]: - all_captures: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list) + def infer_variable_types_from_type_maps( + self, type_maps: list[TypeMap] + ) -> dict[SymbolNode, Type]: + # Type maps may contain variables inherited from previous code which are not + # necessary `Var`s (e.g. a function defined earlier with the same name). + all_captures: dict[SymbolNode, list[tuple[NameExpr, Type]]] = defaultdict(list) for tm in type_maps: if tm is not None: for expr, typ in tm.items(): if isinstance(expr, NameExpr): node = expr.node - assert isinstance(node, Var) + assert node is not None all_captures[node].append((expr, typ)) - inferred_types: dict[Var, Type] = {} + inferred_types: dict[SymbolNode, Type] = {} for var, captures in all_captures.items(): already_exists = False types: list[Type] = [] @@ -5436,16 +5440,19 @@ def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[ new_type = UnionType.make_union(types) # Infer the union type at the first occurrence first_occurrence, _ = captures[0] + # If it didn't exist before ``match``, it's a Var. + assert isinstance(var, Var) inferred_types[var] = new_type self.infer_variable_type(var, first_occurrence, new_type, first_occurrence) return inferred_types - def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var, Type]) -> None: + def remove_capture_conflicts( + self, type_map: TypeMap, inferred_types: dict[SymbolNode, Type] + ) -> None: if type_map: for expr, typ in list(type_map.items()): if isinstance(expr, NameExpr): node = expr.node - assert isinstance(node, Var) if node not in inferred_types or not is_subtype(typ, inferred_types[node]): del type_map[expr] diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index d4af449fc7d7..9adb798c4ae7 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -2471,3 +2471,54 @@ def nested_in_dict(d: dict[str, Any]) -> int: return 0 [builtins fixtures/dict.pyi] + +[case testMatchRebindsOuterFunctionName] +# flags: --warn-unreachable +from typing_extensions import Literal + +def x() -> tuple[Literal["test"]]: ... + +match x(): + case (x,) if x == "test": # E: Incompatible types in capture pattern (pattern captures type "Literal['test']", variable has type "Callable[[], Tuple[Literal['test']]]") + reveal_type(x) # N: Revealed type is "def () -> Tuple[Literal['test']]" + case foo: + foo + +[builtins fixtures/dict.pyi] + +[case testMatchRebindsInnerFunctionName] +# flags: --warn-unreachable +class Some: + value: int | str + __match_args__ = ("value",) + +def fn1(x: Some | int | str) -> None: + match x: + case int(): + def value(): + return 1 + reveal_type(value) # N: Revealed type is "def () -> Any" + case str(): + def value(): + return 1 + reveal_type(value) # N: Revealed type is "def () -> Any" + case Some(value): # E: Incompatible types in capture pattern (pattern captures type "Union[int, str]", variable has type "Callable[[], Any]") + pass + +def fn2(x: Some | int | str) -> None: + match x: + case int(): + def value() -> str: + return "" + reveal_type(value) # N: Revealed type is "def () -> builtins.str" + case str(): + def value() -> int: # E: All conditional function variants must have identical signatures \ + # N: Original: \ + # N: def value() -> str \ + # N: Redefinition: \ + # N: def value() -> int + return 1 + reveal_type(value) # N: Revealed type is "def () -> builtins.str" + case Some(value): # E: Incompatible types in capture pattern (pattern captures type "Union[int, str]", variable has type "Callable[[], str]") + pass +[builtins fixtures/dict.pyi]