From 5bb5bde6e53cdd8ca83273e4bc82b3ff8fe91b67 Mon Sep 17 00:00:00 2001 From: Stephen Morton Date: Thu, 12 Dec 2024 23:54:02 -0800 Subject: [PATCH 1/4] stubtest: handle overloads with mixed pos-only params https://github.com/python/mypy/issues/17023 --- mypy/stubtest.py | 30 ++++++++++++++++++++++-------- mypy/test/teststubtest.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 36cd0a213d4d..233e08e76406 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -824,6 +824,21 @@ def from_overloadedfuncdef(stub: nodes.OverloadedFuncDef) -> Signature[nodes.Arg # For most dunder methods, just assume all args are positional-only assume_positional_only = is_dunder(stub.name, exclude_special=True) + is_arg_pos_only: defaultdict[str, set[bool]] = defaultdict(set) + for func in map(_resolve_funcitem_from_decorator, stub.items): + assert func is not None + args = maybe_strip_cls(stub.name, func.arguments) + for arg in args: + if ( + arg.variable.name.startswith("__") + or arg.pos_only + or assume_positional_only + or arg.variable.name.strip("_") == "self" + ): + is_arg_pos_only[arg.variable.name].add(True) + else: + is_arg_pos_only[arg.variable.name].add(False) + all_args: dict[str, list[tuple[nodes.Argument, int]]] = {} for func in map(_resolve_funcitem_from_decorator, stub.items): assert func is not None @@ -831,14 +846,13 @@ def from_overloadedfuncdef(stub: nodes.OverloadedFuncDef) -> Signature[nodes.Arg for index, arg in enumerate(args): # For positional-only args, we allow overloads to have different names for the same # argument. To accomplish this, we just make up a fake index-based name. - name = ( - f"__{index}" - if arg.variable.name.startswith("__") - or arg.pos_only - or assume_positional_only - or arg.variable.name.strip("_") == "self" - else arg.variable.name - ) + # We can only use the index-based name if the argument is always + # positional only. Sometimes overloads have an arg as positional-only + # in some but not all branches of the overload. + name = arg.variable.name + if is_arg_pos_only[name] == {True}: + name = f"__{index}" + all_args.setdefault(name, []).append((arg, index)) def get_position(arg_name: str) -> int: diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index fcbf07b4d371..0417d19c1879 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -12,7 +12,11 @@ from typing import Any, Callable, Iterator import mypy.stubtest +from mypy import build +from mypy.modulefinder import BuildSource +from mypy.options import Options from mypy.stubtest import parse_options, test_stubs +from mypy.test.config import test_temp_dir from mypy.test.data import root_dir @@ -144,6 +148,14 @@ def __invert__(self: _T) -> _T: pass """ +def build_helper(source: str) -> build.BuildResult: + return build.build( + sources=[BuildSource("main.pyi", None, textwrap.dedent(source))], + options=Options(), + alt_lib_path=test_temp_dir, + ) + + def run_stubtest_with_stderr( stub: str, runtime: str, options: list[str], config_file: str | None = None ) -> tuple[str, str]: @@ -801,6 +813,18 @@ def f2(self, *a) -> int: ... """, error=None, ) + yield Case( + stub=""" + @overload + def f(a: int) -> int: ... + @overload + def f(a: int, b: str, /) -> str: ... + """, + runtime=""" + def f(a, *args): ... + """, + error=None, + ) @collect_cases def test_property(self) -> Iterator[Case]: @@ -2577,6 +2601,21 @@ def test_builtin_signature_with_unrepresentable_default(self) -> None: == "def (self, sep = ..., bytes_per_sep = ...)" ) + def test_overload_signature(self) -> None: + # The same argument as both positional-only and pos-or-kw in + # different overloads used to produce incorrect signatures + source = """ + from typing import overload + @overload + def myfunction(arg: int) -> None: ... + @overload + def myfunction(arg: str, /) -> None: ... + """ + result = build_helper(source) + stub = result.files["__main__"].names["myfunction"].node + sig = mypy.stubtest.Signature.from_overloadedfuncdef(stub) + assert str(sig) == "def (arg: Union[builtins.int, builtins.str])" + def test_config_file(self) -> None: runtime = "temp = 5\n" stub = "from decimal import Decimal\ntemp: Decimal\n" From f432e6f73c5f4e1da2291c95fb2e5cab7e3ce50b Mon Sep 17 00:00:00 2001 From: Stephen Morton Date: Fri, 13 Dec 2024 00:10:29 -0800 Subject: [PATCH 2/4] clarify wording --- mypy/test/teststubtest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 0417d19c1879..01b3a5d84e35 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -2603,7 +2603,7 @@ def test_builtin_signature_with_unrepresentable_default(self) -> None: def test_overload_signature(self) -> None: # The same argument as both positional-only and pos-or-kw in - # different overloads used to produce incorrect signatures + # different overloads previously produced incorrect signatures source = """ from typing import overload @overload From 39656778a4ede9153043c669decafcc40a9bdcff Mon Sep 17 00:00:00 2001 From: Stephen Morton Date: Fri, 13 Dec 2024 00:13:46 -0800 Subject: [PATCH 3/4] type check --- mypy/test/teststubtest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 01b3a5d84e35..8e6f68ddff21 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -2613,6 +2613,7 @@ def myfunction(arg: str, /) -> None: ... """ result = build_helper(source) stub = result.files["__main__"].names["myfunction"].node + assert stub is not None sig = mypy.stubtest.Signature.from_overloadedfuncdef(stub) assert str(sig) == "def (arg: Union[builtins.int, builtins.str])" From 3e9ca378cb15d0845fd32be7426fc6675158e720 Mon Sep 17 00:00:00 2001 From: Stephen Morton Date: Fri, 13 Dec 2024 00:23:03 -0800 Subject: [PATCH 4/4] type check --- mypy/test/teststubtest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 8e6f68ddff21..12e806de110e 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -12,7 +12,7 @@ from typing import Any, Callable, Iterator import mypy.stubtest -from mypy import build +from mypy import build, nodes from mypy.modulefinder import BuildSource from mypy.options import Options from mypy.stubtest import parse_options, test_stubs @@ -2613,7 +2613,7 @@ def myfunction(arg: str, /) -> None: ... """ result = build_helper(source) stub = result.files["__main__"].names["myfunction"].node - assert stub is not None + assert isinstance(stub, nodes.OverloadedFuncDef) sig = mypy.stubtest.Signature.from_overloadedfuncdef(stub) assert str(sig) == "def (arg: Union[builtins.int, builtins.str])"