diff --git a/mypy/stubtest.py b/mypy/stubtest.py index e7cc24f33d18..6061e98bd7cd 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -826,7 +826,10 @@ def from_overloadedfuncdef(stub: nodes.OverloadedFuncDef) -> Signature[nodes.Arg # argument. To accomplish this, we just make up a fake index-based name. name = ( f"__{index}" - if arg.variable.name.startswith("__") or assume_positional_only + if arg.variable.name.startswith("__") + or arg.pos_only + or assume_positional_only + or arg.variable.name.strip("_") == "self" else arg.variable.name ) all_args.setdefault(name, []).append((arg, index)) @@ -870,6 +873,7 @@ def get_kind(arg_name: str) -> nodes.ArgKind: type_annotation=None, initializer=None, kind=get_kind(arg_name), + pos_only=all(arg.pos_only for arg, _ in all_args[arg_name]), ) if arg.kind.is_positional(): sig.pos.append(arg) @@ -905,6 +909,7 @@ def _verify_signature( if ( runtime_arg.kind != inspect.Parameter.POSITIONAL_ONLY and (stub_arg.pos_only or stub_arg.variable.name.startswith("__")) + and stub_arg.variable.name.strip("_") != "self" and not is_dunder(function_name, exclude_special=True) # noisy for dunder methods ): yield ( diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 72b6f6620f83..c0bb2eb6f9da 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -210,7 +210,13 @@ def test(*args: Any, **kwargs: Any) -> None: ) actual_errors = set(output.splitlines()) - assert actual_errors == expected_errors, output + if actual_errors != expected_errors: + output = run_stubtest( + stub="\n\n".join(textwrap.dedent(c.stub.lstrip("\n")) for c in cases), + runtime="\n\n".join(textwrap.dedent(c.runtime.lstrip("\n")) for c in cases), + options=[], + ) + assert actual_errors == expected_errors, output return test @@ -660,6 +666,56 @@ def f6(self, x, /): pass """, error=None, ) + yield Case( + stub=""" + @overload + def f7(a: int, /) -> int: ... + @overload + def f7(b: str, /) -> str: ... + """, + runtime="def f7(x, /): pass", + error=None, + ) + yield Case( + stub=""" + @overload + def f8(a: int, c: int = 0, /) -> int: ... + @overload + def f8(b: str, d: int, /) -> str: ... + """, + runtime="def f8(x, y, /): pass", + error="f8", + ) + yield Case( + stub=""" + @overload + def f9(a: int, c: int = 0, /) -> int: ... + @overload + def f9(b: str, d: int, /) -> str: ... + """, + runtime="def f9(x, y=0, /): pass", + error=None, + ) + yield Case( + stub=""" + class Bar: + @overload + def f1(self) -> int: ... + @overload + def f1(self, a: int, /) -> int: ... + + @overload + def f2(self, a: int, /) -> int: ... + @overload + def f2(self, a: str, /) -> int: ... + """, + runtime=""" + class Bar: + def f1(self, *a) -> int: ... + def f2(self, *a) -> int: ... + """, + error=None, + ) @collect_cases def test_property(self) -> Iterator[Case]: