From 78015b96aad48929b5a85613ef7fed27ff0c6af8 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Thu, 24 Oct 2024 11:06:18 -0700 Subject: [PATCH] Enhanced type narrowing logic for "x == " type guard pattern to handle the case where `x` is a type variable with a literal upper bound or value constraints that are literals. This addresses #9300. --- .../src/analyzer/typeGuards.ts | 2 +- .../tests/samples/typeNarrowingLiteral1.py | 26 ++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 8530284a5a12..3db90d2c5a45 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -2468,7 +2468,7 @@ function narrowTypeForLiteralComparison( isPositiveTest: boolean, isIsOperator: boolean ): Type { - return mapSubtypes(referenceType, (subtype) => { + return evaluator.mapSubtypesExpandTypeVars(referenceType, /* options */ undefined, (subtype) => { subtype = evaluator.makeTopLevelTypeVarsConcrete(subtype); if (isAnyOrUnknown(subtype)) { diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py index b7733d5c7e7a..fffe7251350e 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingLiteral1.py @@ -1,7 +1,7 @@ # This sample tests the type analyzer's type narrowing # logic for literals. -from typing import Literal, Union +from typing import Literal, TypeVar, Union def func1(p1: Literal["a", "b", "c"]): @@ -28,3 +28,27 @@ def func2(p1: Literal[1, 4, 7]): def func3(a: Union[int, None]): if a == 1 or a == 2: reveal_type(a, expected_text="Literal[1, 2]") + + +T = TypeVar("T", bound=Literal["a", "b"]) + + +def func4(x: T) -> T: + if x == "a": + reveal_type(x, expected_text="Literal['a']") + return x + else: + reveal_type(x, expected_text="Literal['b']") + return x + + +S = TypeVar("S", Literal["a"], Literal["b"]) + + +def func5(x: S) -> S: + if x == "a": + reveal_type(x, expected_text="Literal['a']") + return x + else: + reveal_type(x, expected_text="Literal['b']") + return x