From 32b860e319813f2bfc2499365b714da133c289d6 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 7 Jan 2025 02:53:47 +0100 Subject: [PATCH] [stubgen] Improve self annotations (#18420) Print annotations for self variables if given. Aside from the most common ones for `str`, `int`, `bool` etc. those were previously inferred as `Incomplete`. --- mypy/stubgen.py | 10 +++++----- test-data/unit/stubgen.test | 11 +++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index ca1fda27a976..27d868ed2624 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -648,11 +648,11 @@ def visit_func_def(self, o: FuncDef) -> None: self.add("\n") if not self.is_top_level(): self_inits = find_self_initializers(o) - for init, value in self_inits: + for init, value, annotation in self_inits: if init in self.method_names: # Can't have both an attribute and a method/property with the same name. continue - init_code = self.get_init(init, value) + init_code = self.get_init(init, value, annotation) if init_code: self.add(init_code) @@ -1414,7 +1414,7 @@ def find_method_names(defs: list[Statement]) -> set[str]: class SelfTraverser(mypy.traverser.TraverserVisitor): def __init__(self) -> None: - self.results: list[tuple[str, Expression]] = [] + self.results: list[tuple[str, Expression, Type | None]] = [] def visit_assignment_stmt(self, o: AssignmentStmt) -> None: lvalue = o.lvalues[0] @@ -1423,10 +1423,10 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: and isinstance(lvalue.expr, NameExpr) and lvalue.expr.name == "self" ): - self.results.append((lvalue.name, o.rvalue)) + self.results.append((lvalue.name, o.rvalue, o.unanalyzed_type)) -def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression]]: +def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression, Type | None]]: """Find attribute initializers in a method. Return a list of pairs (attribute name, r.h.s. expression). diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 0801d9a27011..9cfe301a9d0b 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -238,13 +238,24 @@ class C: def __init__(self, x: str) -> None: ... [case testSelfAssignment] +from mod import A +from typing import Any, Dict, Union class C: def __init__(self): + self.a: A = A() self.x = 1 x.y = 2 + self.y: Dict[str, Any] = {} + self.z: Union[int, str, bool, None] = None [out] +from mod import A +from typing import Any + class C: + a: A x: int + y: dict[str, Any] + z: int | str | bool | None def __init__(self) -> None: ... [case testSelfAndClassBodyAssignment]