From f820bc4902c606b49ff15520724aed17c4dfae9a Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Thu, 14 Nov 2024 12:40:26 -0800 Subject: [PATCH] Changed inference logic for exception groups to more closely match the runtime. If a non-base exception is targeted, the inferred type is now `ExceptionGroup` rather than `BaseExceptionGroup`. This addresses #9466. --- .../src/analyzer/typeEvaluator.ts | 12 ++++++++++-- .../src/tests/samples/exceptionGroup1.py | 18 +++++++++++++++++- .../src/tests/typeEvaluator7.test.ts | 2 +- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index c1229a3815ce..823f891fe892 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -19446,6 +19446,7 @@ export function createTypeEvaluator( const exceptionTypeResult = getTypeOfExpression(node.d.typeExpr!); const exceptionTypes = exceptionTypeResult.type; + let includesBaseException = false; function getExceptionType(exceptionType: Type, errorNode: ExpressionNode) { exceptionType = makeTopLevelTypeVarsConcrete(exceptionType); @@ -19455,6 +19456,9 @@ export function createTypeEvaluator( } if (isInstantiableClass(exceptionType)) { + if (ClassType.isBuiltIn(exceptionType, 'BaseException')) { + includesBaseException = true; + } return ClassType.cloneAsInstance(exceptionType); } @@ -19492,9 +19496,13 @@ export function createTypeEvaluator( return getExceptionType(subType, node.d.typeExpr!); }); - // If this is an except group, wrap the exception type in an BaseExceptionGroup. + // If this is an except group, wrap the exception type in an ExceptionGroup + // or BaseExceptionGroup depending on whether the target exception is + // a BaseException. if (node.d.isExceptGroup) { - targetType = getBuiltInObject(node, 'BaseExceptionGroup', [targetType]); + targetType = getBuiltInObject(node, includesBaseException ? 'BaseExceptionGroup' : 'ExceptionGroup', [ + targetType, + ]); } if (node.d.name) { diff --git a/packages/pyright-internal/src/tests/samples/exceptionGroup1.py b/packages/pyright-internal/src/tests/samples/exceptionGroup1.py index 4b286ccd9589..f6ca41412f0a 100644 --- a/packages/pyright-internal/src/tests/samples/exceptionGroup1.py +++ b/packages/pyright-internal/src/tests/samples/exceptionGroup1.py @@ -9,7 +9,7 @@ def func1(): # This should generate an error if using Python 3.10 or earlier. except* ValueError as e: - reveal_type(e, expected_text="BaseExceptionGroup[ValueError]") + reveal_type(e, expected_text="ExceptionGroup[ValueError]") pass # This should generate an error if using Python 3.10 or earlier. @@ -105,3 +105,19 @@ def inner(): # return is not allowed in an except* block. return + + +def func8(): + + try: + pass + + # This should generate an error if using Python 3.10 or earlier. + except* (ValueError, FloatingPointError) as e: + reveal_type(e, expected_text="ExceptionGroup[ValueError | FloatingPointError]") + pass + + # This should generate an error if using Python 3.10 or earlier. + except* BaseException as e: + reveal_type(e, expected_text="BaseExceptionGroup[BaseException]") + pass diff --git a/packages/pyright-internal/src/tests/typeEvaluator7.test.ts b/packages/pyright-internal/src/tests/typeEvaluator7.test.ts index 3dcef816cec5..ea9bc1cf380c 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator7.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator7.test.ts @@ -971,7 +971,7 @@ test('exceptionGroup1', () => { configOptions.defaultPythonVersion = pythonVersion3_10; const analysisResults1 = TestUtils.typeAnalyzeSampleFiles(['exceptionGroup1.py'], configOptions); - TestUtils.validateResults(analysisResults1, 28); + TestUtils.validateResults(analysisResults1, 34); configOptions.defaultPythonVersion = pythonVersion3_11; const analysisResults2 = TestUtils.typeAnalyzeSampleFiles(['exceptionGroup1.py'], configOptions);