Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubtest: handle overloads with mixed pos-only params #18287

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,21 +824,35 @@ def from_overloadedfuncdef(stub: nodes.OverloadedFuncDef) -> Signature[nodes.Arg
# For most dunder methods, just assume all args are positional-only
assume_positional_only = is_dunder(stub.name, exclude_special=True)

is_arg_pos_only: defaultdict[str, set[bool]] = defaultdict(set)
for func in map(_resolve_funcitem_from_decorator, stub.items):
assert func is not None
args = maybe_strip_cls(stub.name, func.arguments)
for arg in args:
if (
arg.variable.name.startswith("__")
or arg.pos_only
or assume_positional_only
or arg.variable.name.strip("_") == "self"
):
is_arg_pos_only[arg.variable.name].add(True)
else:
is_arg_pos_only[arg.variable.name].add(False)

all_args: dict[str, list[tuple[nodes.Argument, int]]] = {}
for func in map(_resolve_funcitem_from_decorator, stub.items):
assert func is not None
args = maybe_strip_cls(stub.name, func.arguments)
for index, arg in enumerate(args):
# For positional-only args, we allow overloads to have different names for the same
# argument. To accomplish this, we just make up a fake index-based name.
name = (
f"__{index}"
if arg.variable.name.startswith("__")
or arg.pos_only
or assume_positional_only
or arg.variable.name.strip("_") == "self"
else arg.variable.name
)
# We can only use the index-based name if the argument is always
# positional only. Sometimes overloads have an arg as positional-only
# in some but not all branches of the overload.
name = arg.variable.name
if is_arg_pos_only[name] == {True}:
name = f"__{index}"

all_args.setdefault(name, []).append((arg, index))

def get_position(arg_name: str) -> int:
Expand Down
40 changes: 40 additions & 0 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from typing import Any, Callable, Iterator

import mypy.stubtest
from mypy import build, nodes
from mypy.modulefinder import BuildSource
from mypy.options import Options
from mypy.stubtest import parse_options, test_stubs
from mypy.test.config import test_temp_dir
from mypy.test.data import root_dir


Expand Down Expand Up @@ -144,6 +148,14 @@ def __invert__(self: _T) -> _T: pass
"""


def build_helper(source: str) -> build.BuildResult:
return build.build(
sources=[BuildSource("main.pyi", None, textwrap.dedent(source))],
options=Options(),
alt_lib_path=test_temp_dir,
)


def run_stubtest_with_stderr(
stub: str, runtime: str, options: list[str], config_file: str | None = None
) -> tuple[str, str]:
Expand Down Expand Up @@ -801,6 +813,18 @@ def f2(self, *a) -> int: ...
""",
error=None,
)
yield Case(
stub="""
@overload
def f(a: int) -> int: ...
@overload
def f(a: int, b: str, /) -> str: ...
""",
runtime="""
def f(a, *args): ...
""",
error=None,
)

@collect_cases
def test_property(self) -> Iterator[Case]:
Expand Down Expand Up @@ -2577,6 +2601,22 @@ def test_builtin_signature_with_unrepresentable_default(self) -> None:
== "def (self, sep = ..., bytes_per_sep = ...)"
)

def test_overload_signature(self) -> None:
# The same argument as both positional-only and pos-or-kw in
# different overloads previously produced incorrect signatures
source = """
from typing import overload
@overload
def myfunction(arg: int) -> None: ...
@overload
def myfunction(arg: str, /) -> None: ...
"""
result = build_helper(source)
stub = result.files["__main__"].names["myfunction"].node
assert isinstance(stub, nodes.OverloadedFuncDef)
sig = mypy.stubtest.Signature.from_overloadedfuncdef(stub)
assert str(sig) == "def (arg: Union[builtins.int, builtins.str])"

def test_config_file(self) -> None:
runtime = "temp = 5\n"
stub = "from decimal import Decimal\ntemp: Decimal\n"
Expand Down
Loading