From 211ae6224376df5e715758d8d424ab7fc401fc65 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Tue, 29 Oct 2024 15:32:32 -0700 Subject: [PATCH] Fixed a bug that causes a false positive error when a class uses `type(Protocol)` as a base class. This addresses #9217. This fix involves a change to the internal isSameGenericClass method, which was previously too permissive. This change required dozens of downstream changes, and it has a high risk of regression. --- .../pyright-internal/src/analyzer/checker.ts | 4 +- .../src/analyzer/codeFlowEngine.ts | 4 +- .../src/analyzer/constructors.ts | 6 ++- .../pyright-internal/src/analyzer/enums.ts | 6 ++- .../src/analyzer/patternMatching.ts | 2 +- .../src/analyzer/typeEvaluator.ts | 37 ++++++++++++++----- .../src/analyzer/typeGuards.ts | 19 +++++++--- .../src/analyzer/typePrinter.ts | 8 ++++ .../src/analyzer/typeUtils.ts | 6 ++- .../pyright-internal/src/analyzer/types.ts | 9 ++--- .../src/languageService/completionProvider.ts | 5 ++- 11 files changed, 76 insertions(+), 30 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index 83c81d8f6e27..e3bdba5c1458 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -3846,7 +3846,7 @@ export class Checker extends ParseTreeWalker { if (isInstantiableClass(filterType)) { this._validateUnsafeProtocolOverlap( node.d.args[0].d.valueExpr, - convertToInstance(filterType), + ClassType.cloneAsInstance(filterType), isInstanceCheck ? arg0Type : convertToInstance(arg0Type) ); } @@ -4898,7 +4898,7 @@ export class Checker extends ParseTreeWalker { if ( !symbolType || !isClassInstance(symbolType) || - !ClassType.isSameGenericClass(symbolType, classType) || + !ClassType.isSameGenericClass(symbolType, ClassType.cloneAsInstance(classType)) || !(symbolType.priv.literalValue instanceof EnumLiteral) ) { return; diff --git a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts index 9d01a3460261..017225fb4858 100644 --- a/packages/pyright-internal/src/analyzer/codeFlowEngine.ts +++ b/packages/pyright-internal/src/analyzer/codeFlowEngine.ts @@ -1583,7 +1583,7 @@ export function getCodeFlowEngine( ); return priorRemainingConstraints.filter((subtype) => - ClassType.isSameGenericClass(subtype, classType) + ClassType.isSameGenericClass(subtype, ClassType.cloneAsInstance(classType)) ); } } @@ -1632,7 +1632,7 @@ export function getCodeFlowEngine( if (isInstantiableClass(arg1Type)) { return priorRemainingConstraints.filter((subtype) => { - if (ClassType.isSameGenericClass(subtype, arg1Type)) { + if (ClassType.isSameGenericClass(subtype, ClassType.cloneAsInstance(arg1Type))) { return isPositiveTest; } else { return !isPositiveTest; diff --git a/packages/pyright-internal/src/analyzer/constructors.ts b/packages/pyright-internal/src/analyzer/constructors.ts index 3b06cff783fe..effdb91f1e4b 100644 --- a/packages/pyright-internal/src/analyzer/constructors.ts +++ b/packages/pyright-internal/src/analyzer/constructors.ts @@ -1068,7 +1068,11 @@ function shouldSkipInitEvaluation(evaluator: TypeEvaluator, classType: ClassType if (isClassInstance(subtype)) { const inheritanceChain: InheritanceChain = []; - const isDerivedFrom = ClassType.isDerivedFrom(subtype, classType, inheritanceChain); + const isDerivedFrom = ClassType.isDerivedFrom( + ClassType.cloneAsInstantiable(subtype), + classType, + inheritanceChain + ); if (!isDerivedFrom) { skipInitCheck = true; diff --git a/packages/pyright-internal/src/analyzer/enums.ts b/packages/pyright-internal/src/analyzer/enums.ts index 683e2990a0d1..d5b325e6b2a5 100644 --- a/packages/pyright-internal/src/analyzer/enums.ts +++ b/packages/pyright-internal/src/analyzer/enums.ts @@ -56,7 +56,11 @@ export function isEnumClassWithMembers(evaluator: TypeEvaluator, classType: Clas ClassType.getSymbolTable(classType).forEach((symbol, name) => { const symbolType = transformTypeForEnumMember(evaluator, classType, name); - if (symbolType && isClassInstance(symbolType) && ClassType.isSameGenericClass(symbolType, classType)) { + if ( + symbolType && + isClassInstance(symbolType) && + ClassType.isSameGenericClass(symbolType, ClassType.cloneAsInstance(classType)) + ) { definesMember = true; } }); diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index ccc4ebeafd5c..3dd371077daf 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -760,7 +760,7 @@ function narrowTypeBasedOnClassPattern( classType = ClassType.specialize(classType, /* typeArgs */ undefined); } - const classInstance = convertToInstance(classType); + const classInstance = ClassType.cloneAsInstance(classType); const isPatternMetaclass = isMetaclassInstance(classInstance); return evaluator.mapSubtypesExpandTypeVars( diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index b41f7478fc16..f78ae9c697b0 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -3584,7 +3584,12 @@ export function createTypeEvaluator( enclosingClass = classTypeResults.classType; if (isClassInstance(baseType)) { - if (ClassType.isSameGenericClass(baseType, classTypeResults.classType)) { + if ( + ClassType.isSameGenericClass( + ClassType.cloneAsInstantiable(baseType), + classTypeResults.classType + ) + ) { assignTypeToMemberVariable(target, typeResult, /* isInstanceMember */ true, srcExpr); } } else if (isInstantiableClass(baseType)) { @@ -5603,7 +5608,7 @@ export function createTypeEvaluator( // Is this an attempt to delete or overwrite an enum member? if ( isClassInstance(enumMemberResult.type) && - ClassType.isSameGenericClass(enumMemberResult.type, baseType) && + ClassType.isSameGenericClass(enumMemberResult.type, ClassType.cloneAsInstance(baseType)) && enumMemberResult.type.priv.literalValue !== undefined ) { const diagMessage = @@ -5995,7 +6000,10 @@ export function createTypeEvaluator( if ( containingClassType && isInstantiableClass(containingClassType) && - ClassType.isSameGenericClass(containingClassType, classType) + ClassType.isSameGenericClass( + isAccessedThroughObject ? ClassType.cloneAsInstance(containingClassType) : containingClassType, + classType + ) ) { type = getDeclaredTypeOfSymbol(memberInfo.symbol)?.type; if (type && isInstantiableClass(memberInfo.classType)) { @@ -6069,7 +6077,10 @@ export function createTypeEvaluator( if ( errorNode && isInstantiableClass(memberInfo.classType) && - ClassType.isSameGenericClass(memberInfo.classType, classType) + ClassType.isSameGenericClass( + memberInfo.classType, + isAccessedThroughObject ? ClassType.cloneAsInstantiable(classType) : classType + ) ) { setSymbolAccessed(AnalyzerNodeInfo.getFileInfo(errorNode), memberInfo.symbol, errorNode); } @@ -8837,7 +8848,10 @@ export function createTypeEvaluator( bindToType && ClassType.isProtocolClass(bindToType) && effectiveTargetClass && - !ClassType.isSameGenericClass(bindToType, effectiveTargetClass) + !ClassType.isSameGenericClass( + TypeBase.isInstance(bindToType) ? ClassType.cloneAsInstantiable(bindToType) : bindToType, + effectiveTargetClass + ) ) { isProtocolClass = true; effectiveTargetClass = undefined; @@ -8909,7 +8923,12 @@ export function createTypeEvaluator( if (bindToType) { let nextBaseClassType: Type | undefined; - if (ClassType.isSameGenericClass(bindToType, concreteTargetClassType)) { + if ( + ClassType.isSameGenericClass( + TypeBase.isInstance(bindToType) ? ClassType.cloneAsInstantiable(bindToType) : bindToType, + concreteTargetClassType + ) + ) { if (bindToType.shared.baseClasses.length > 0) { nextBaseClassType = bindToType.shared.baseClasses[0]; } @@ -13475,7 +13494,7 @@ export function createTypeEvaluator( if (isNoneInstance(subtype)) { if (objectClass && isInstantiableClass(objectClass)) { // Use 'object' for 'None'. - return handleSubtype(convertToInstance(objectClass)); + return handleSubtype(ClassType.cloneAsInstance(objectClass)); } } @@ -24344,8 +24363,8 @@ export function createTypeEvaluator( if (destMetaclass && isInstantiableClass(destMetaclass)) { if ( assignClass( - ClassType.cloneAsInstance(destMetaclass), - expandedSrcType, + destMetaclass, + ClassType.cloneAsInstantiable(expandedSrcType), diag, constraints, flags, diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 3db90d2c5a45..2f08cdf8f618 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -1348,7 +1348,7 @@ function narrowTypeForInstance( // any metaclass, but we specifically want to treat type as the class // type[object] in this case. if (ClassType.isBuiltIn(filterMetaclass, 'type') && !filterMetaclass.priv.isTypeArgExplicit) { - if (!ClassType.isBuiltIn(metaclassType, 'type')) { + if (!isClass(metaclassType) || !ClassType.isBuiltIn(metaclassType, 'type')) { isMetaclassOverlap = false; } } @@ -1435,7 +1435,14 @@ function narrowTypeForInstance( } if (filterIsSubclass && !ClassType.isSameGenericClass(runtimeVarType, concreteFilterType)) { - isClassRelationshipIndeterminate = true; + // If the runtime variable type is a type[T], handle a filter + // of 'type' as a special case. + if ( + !ClassType.isBuiltIn(concreteFilterType, 'type') || + TypeBase.getInstantiableDepth(runtimeVarType) === 0 + ) { + isClassRelationshipIndeterminate = true; + } } } @@ -1479,8 +1486,8 @@ function narrowTypeForInstance( if ( addConstraintsForExpectedType( evaluator, - convertToInstance(unspecializedFilterType), - convertToInstance(concreteVarType), + ClassType.cloneAsInstance(unspecializedFilterType), + ClassType.cloneAsInstance(concreteVarType), constraints, /* liveTypeVarScopes */ undefined, errorNode.start @@ -1664,7 +1671,7 @@ function narrowTypeForInstance( const isFilterTypeCallbackProtocol = (filterType: Type) => { return ( isInstantiableClass(filterType) && - evaluator.getCallbackProtocolType(convertToInstance(filterType)) !== undefined + evaluator.getCallbackProtocolType(ClassType.cloneAsInstance(filterType)) !== undefined ); }; @@ -2324,7 +2331,7 @@ function narrowTypeForTypeIs(evaluator: TypeEvaluator, type: Type, classType: Cl const matches = ClassType.isDerivedFrom(classType, ClassType.cloneAsInstantiable(subtype)); if (isPositiveTest) { if (matches) { - if (ClassType.isSameGenericClass(subtype, classType)) { + if (ClassType.isSameGenericClass(ClassType.cloneAsInstantiable(subtype), classType)) { return addConditionToType(subtype, classType.props?.condition); } diff --git a/packages/pyright-internal/src/analyzer/typePrinter.ts b/packages/pyright-internal/src/analyzer/typePrinter.ts index 69861afd0252..6d924c6de2a8 100644 --- a/packages/pyright-internal/src/analyzer/typePrinter.ts +++ b/packages/pyright-internal/src/analyzer/typePrinter.ts @@ -1466,6 +1466,14 @@ class UniqueNameMap { } if (isClass(type1) && isClass(type2)) { + while (TypeBase.isInstantiable(type1)) { + type1 = ClassType.cloneAsInstance(type1); + } + + while (TypeBase.isInstantiable(type2)) { + type2 = ClassType.cloneAsInstance(type2); + } + return ClassType.isSameGenericClass(type1, type2); } diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index a6c742d481c5..16792548327c 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -1839,7 +1839,10 @@ export function* getClassMemberIterator( if ( memberName === '__call__' && classType.priv.partialCallType && - ClassType.isSameGenericClass(classType, specializedMroClass) + ClassType.isSameGenericClass( + TypeBase.isInstance(classType) ? ClassType.cloneAsInstantiable(classType) : classType, + specializedMroClass + ) ) { symbol = Symbol.createWithType(SymbolFlags.ClassMember, classType.priv.partialCallType); } @@ -2298,7 +2301,6 @@ export function isEffectivelyInstantiable(type: Type, options?: IsInstantiableOp return false; } -export function convertToInstance(type: ClassType, includeSubclasses?: boolean): ClassType; export function convertToInstance(type: ParamSpecType, includeSubclasses?: boolean): ParamSpecType; export function convertToInstance(type: TypeVarTupleType, includeSubclasses?: boolean): TypeVarTupleType; export function convertToInstance(type: TypeVarType, includeSubclasses?: boolean): TypeVarType; diff --git a/packages/pyright-internal/src/analyzer/types.ts b/packages/pyright-internal/src/analyzer/types.ts index 10974235118a..c0c7bfbb8fd9 100644 --- a/packages/pyright-internal/src/analyzer/types.ts +++ b/packages/pyright-internal/src/analyzer/types.ts @@ -1324,13 +1324,12 @@ export namespace ClassType { return false; } - // Handle type[] specially. - if (TypeBase.getInstantiableDepth(classType) > 0) { - return TypeBase.isInstantiable(type2) || ClassType.isBuiltIn(type2, 'type'); + if (TypeBase.isInstance(classType) !== TypeBase.isInstance(type2)) { + return false; } - if (TypeBase.getInstantiableDepth(type2) > 0) { - return TypeBase.isInstantiable(classType) || ClassType.isBuiltIn(classType, 'type'); + if (TypeBase.getInstantiableDepth(classType) !== TypeBase.getInstantiableDepth(type2)) { + return false; } const class1Details = classType.shared; diff --git a/packages/pyright-internal/src/languageService/completionProvider.ts b/packages/pyright-internal/src/languageService/completionProvider.ts index 6f50c68b2866..2bf0b1dbe42c 100644 --- a/packages/pyright-internal/src/languageService/completionProvider.ts +++ b/packages/pyright-internal/src/languageService/completionProvider.ts @@ -3183,7 +3183,10 @@ export class CompletionProvider { return ( symbolType && isClassInstance(symbolType) && - ClassType.isSameGenericClass(symbolType, containingType) && + ClassType.isSameGenericClass( + symbolType, + TypeBase.isInstance(containingType) ? containingType : ClassType.cloneAsInstance(containingType) + ) && symbolType.priv.literalValue instanceof EnumLiteral ); }