diff --git a/packages/pyright-internal/src/analyzer/binder.ts b/packages/pyright-internal/src/analyzer/binder.ts index 4825323a6f3e..6c2a81c029a5 100644 --- a/packages/pyright-internal/src/analyzer/binder.ts +++ b/packages/pyright-internal/src/analyzer/binder.ts @@ -201,6 +201,10 @@ export class Binder extends ParseTreeWalker { // and require code flow analysis to resolve. private _currentScopeCodeFlowExpressions: Set | undefined; + // If we're actively binding a match statement, this is the current + // match expression. + private _currentMatchSubjExpr: ExpressionNode | undefined; + // Aliases of "typing" and "typing_extensions". private _typingImportAliases: string[] = []; @@ -2275,10 +2279,20 @@ export class Binder extends ParseTreeWalker { this._currentFlowNode = this._finishFlowLabel(preGuardLabel); + // Note the active match subject expression prior to binding + // the pattern. If the pattern involves any targets that overwrite + // the subject expression, this will be set to undefined. + this._currentMatchSubjExpr = node.d.expr; + // Bind the pattern. this.walk(caseStatement.d.pattern); - this._createFlowNarrowForPattern(node.d.expr, caseStatement); + // If the pattern involves targets that overwrite the subject + // expression, skip creating a flow node for narrowing the subject. + if (this._currentMatchSubjExpr) { + this._createFlowNarrowForPattern(node.d.expr, caseStatement); + this._currentMatchSubjExpr = undefined; + } // Apply the guard expression. if (caseStatement.d.guardExpr) { @@ -2465,6 +2479,16 @@ export class Binder extends ParseTreeWalker { const symbol = this._bindNameToScope(this._currentScope, target); this._createAssignmentTargetFlowNodes(target, /* walkTargets */ false, /* unbound */ false); + // See if the target overwrites all or a portion of the subject expression. + if (this._currentMatchSubjExpr) { + if ( + ParseTreeUtils.isMatchingExpression(target, this._currentMatchSubjExpr) || + ParseTreeUtils.isPartialMatchingExpression(target, this._currentMatchSubjExpr) + ) { + this._currentMatchSubjExpr = undefined; + } + } + if (symbol) { const declaration: VariableDeclaration = { type: DeclarationType.Variable, diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index 3dd371077daf..ceaa1eea6c06 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -859,7 +859,8 @@ function narrowTypeBasedOnClassPattern( LocAddendum.typeNotClass().format({ type: evaluator.printType(exprType) }), pattern.d.className ); - return NeverType.createNever(); + + return isPositiveTest ? UnknownType.create() : type; } else if (isInstantiableClass(exprType)) { if (ClassType.isProtocolClass(exprType) && !ClassType.isRuntimeCheckable(exprType)) { evaluator.addDiagnostic( @@ -867,12 +868,16 @@ function narrowTypeBasedOnClassPattern( LocAddendum.protocolRequiresRuntimeCheckable(), pattern.d.className ); + + return isPositiveTest ? UnknownType.create() : type; } else if (ClassType.isTypedDictClass(exprType)) { evaluator.addDiagnostic( DiagnosticRule.reportGeneralTypeIssues, LocMessage.typedDictInClassPattern(), pattern.d.className ); + + return isPositiveTest ? UnknownType.create() : type; } } diff --git a/packages/pyright-internal/src/tests/samples/matchClass7.py b/packages/pyright-internal/src/tests/samples/matchClass7.py new file mode 100644 index 000000000000..c2f4075650d9 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/matchClass7.py @@ -0,0 +1,34 @@ +# This sample tests the case where a class pattern overwrites the subject +# expression. + +from dataclasses import dataclass + + +@dataclass +class DC1: + val: str + + +def func1(val: DC1): + result = val + + match result: + case DC1(result): + reveal_type(result, expected_text="str") + + +@dataclass +class DC2: + val: DC1 + + +def func2(val: DC2): + result = val + + match result.val: + case DC1(result): + reveal_type(result, expected_text="str") + + # This should generate an error because result.val + # is no longer valid at this point. + print(result.val) diff --git a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts index 1e89799dd666..b5c8c98db9ca 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator6.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator6.test.ts @@ -520,6 +520,14 @@ test('MatchClass6', () => { TestUtils.validateResults(analysisResults, 0); }); +test('MatchClass7', () => { + const configOptions = new ConfigOptions(Uri.empty()); + + configOptions.defaultPythonVersion = pythonVersion3_10; + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['matchClass7.py'], configOptions); + TestUtils.validateResults(analysisResults, 1); +}); + test('MatchValue1', () => { const configOptions = new ConfigOptions(Uri.empty());