diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 36cd0a213d4d..233e08e76406 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -824,6 +824,21 @@ def from_overloadedfuncdef(stub: nodes.OverloadedFuncDef) -> Signature[nodes.Arg # For most dunder methods, just assume all args are positional-only assume_positional_only = is_dunder(stub.name, exclude_special=True) + is_arg_pos_only: defaultdict[str, set[bool]] = defaultdict(set) + for func in map(_resolve_funcitem_from_decorator, stub.items): + assert func is not None + args = maybe_strip_cls(stub.name, func.arguments) + for arg in args: + if ( + arg.variable.name.startswith("__") + or arg.pos_only + or assume_positional_only + or arg.variable.name.strip("_") == "self" + ): + is_arg_pos_only[arg.variable.name].add(True) + else: + is_arg_pos_only[arg.variable.name].add(False) + all_args: dict[str, list[tuple[nodes.Argument, int]]] = {} for func in map(_resolve_funcitem_from_decorator, stub.items): assert func is not None @@ -831,14 +846,13 @@ def from_overloadedfuncdef(stub: nodes.OverloadedFuncDef) -> Signature[nodes.Arg for index, arg in enumerate(args): # For positional-only args, we allow overloads to have different names for the same # argument. To accomplish this, we just make up a fake index-based name. - name = ( - f"__{index}" - if arg.variable.name.startswith("__") - or arg.pos_only - or assume_positional_only - or arg.variable.name.strip("_") == "self" - else arg.variable.name - ) + # We can only use the index-based name if the argument is always + # positional only. Sometimes overloads have an arg as positional-only + # in some but not all branches of the overload. + name = arg.variable.name + if is_arg_pos_only[name] == {True}: + name = f"__{index}" + all_args.setdefault(name, []).append((arg, index)) def get_position(arg_name: str) -> int: diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index fcbf07b4d371..12e806de110e 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -12,7 +12,11 @@ from typing import Any, Callable, Iterator import mypy.stubtest +from mypy import build, nodes +from mypy.modulefinder import BuildSource +from mypy.options import Options from mypy.stubtest import parse_options, test_stubs +from mypy.test.config import test_temp_dir from mypy.test.data import root_dir @@ -144,6 +148,14 @@ def __invert__(self: _T) -> _T: pass """ +def build_helper(source: str) -> build.BuildResult: + return build.build( + sources=[BuildSource("main.pyi", None, textwrap.dedent(source))], + options=Options(), + alt_lib_path=test_temp_dir, + ) + + def run_stubtest_with_stderr( stub: str, runtime: str, options: list[str], config_file: str | None = None ) -> tuple[str, str]: @@ -801,6 +813,18 @@ def f2(self, *a) -> int: ... """, error=None, ) + yield Case( + stub=""" + @overload + def f(a: int) -> int: ... + @overload + def f(a: int, b: str, /) -> str: ... + """, + runtime=""" + def f(a, *args): ... + """, + error=None, + ) @collect_cases def test_property(self) -> Iterator[Case]: @@ -2577,6 +2601,22 @@ def test_builtin_signature_with_unrepresentable_default(self) -> None: == "def (self, sep = ..., bytes_per_sep = ...)" ) + def test_overload_signature(self) -> None: + # The same argument as both positional-only and pos-or-kw in + # different overloads previously produced incorrect signatures + source = """ + from typing import overload + @overload + def myfunction(arg: int) -> None: ... + @overload + def myfunction(arg: str, /) -> None: ... + """ + result = build_helper(source) + stub = result.files["__main__"].names["myfunction"].node + assert isinstance(stub, nodes.OverloadedFuncDef) + sig = mypy.stubtest.Signature.from_overloadedfuncdef(stub) + assert str(sig) == "def (arg: Union[builtins.int, builtins.str])" + def test_config_file(self) -> None: runtime = "temp = 5\n" stub = "from decimal import Decimal\ntemp: Decimal\n"