Skip to content

Commit

Permalink
Added special-case handling x in y narrowing logic for the case whe…
Browse files Browse the repository at this point in the history
…re `x` is a `dict` or `Mapping` and `y` is an iterable of `TypedDict`s. This addresses #6436.
  • Loading branch information
erictraut committed Nov 14, 2023
1 parent f547a08 commit 058d206
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
9 changes: 9 additions & 0 deletions packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2051,6 +2051,15 @@ export function narrowTypeForContainerElementType(evaluator: TypeEvaluator, refe
return referenceType;
}

// Handle the special case where the reference type is a dict or Mapping and
// the element type is a TypedDict. In this case, we can't say whether there
// is a type overlap, so don't apply narrowing.
if (isClassInstance(referenceType) && ClassType.isBuiltIn(referenceType, ['dict', 'Mapping'])) {
if (isClassInstance(concreteElementType) && ClassType.isTypedDictClass(concreteElementType)) {
return concreteElementType;
}
}

if (evaluator.assignType(referenceType, concreteElementType)) {
return concreteElementType;
}
Expand Down
17 changes: 16 additions & 1 deletion packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This sample tests type narrowing for the "in" operator.

from typing import Literal
from typing import Literal, TypedDict
import random


Expand Down Expand Up @@ -138,3 +138,18 @@ def func10(x: Literal["A", "B"], y: tuple[Literal["A"], ...]):
reveal_type(x, expected_text="Literal['A']")
else:
reveal_type(x, expected_text="Literal['A', 'B']")


class TD1(TypedDict):
x: str


class TD2(TypedDict):
y: str


def func11(x: dict[str, str]):
if x in (TD1(x="a"), TD2(y="b")):
reveal_type(x, expected_text="TD1 | TD2")
else:
reveal_type(x, expected_text="dict[str, str]")
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ def func2(x: MyEnum):
reveal_type(x, expected_text="Literal[MyEnum.A, MyEnum.B]")
else:
reveal_type(x, expected_text="Literal[MyEnum.C]")

0 comments on commit 058d206

Please sign in to comment.