Skip to content

Commit

Permalink
Improved pyright's enforcement of keyword arguments passed within a `…
Browse files Browse the repository at this point in the history
…class` statement when the class has no custom metaclass or `__init_subclass__` in its hierarchy. In this case, the `object.__init_subclass__` method applies, and it accepts no additional keyword arguments. Also improved the error reporting for `__init_subclass__` in general. This addresses #6403. (#6405)
  • Loading branch information
erictraut authored Nov 9, 2023
1 parent 9a56c39 commit de7625f
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 107 deletions.
214 changes: 121 additions & 93 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16702,111 +16702,139 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
});

const errorNode = argList.length > 0 ? argList[0].node?.name ?? node.name : node.name;
const initSubclassMethodInfo = getTypeOfBoundMember(
errorNode,
classType,
'__init_subclass__',
/* usage */ undefined,
/* diag */ undefined,
MemberAccessFlags.SkipClassMembers |
MemberAccessFlags.SkipObjectBaseClass |
MemberAccessFlags.SkipOriginalClass |
MemberAccessFlags.SkipAttributeAccessOverride
);
let newMethodMember: ClassMember | undefined;

if (initSubclassMethodInfo) {
const initSubclassMethodType = initSubclassMethodInfo.type;

if (initSubclassMethodType) {
validateCallArguments(
errorNode,
argList,
{ type: initSubclassMethodType },
/* typeVarContext */ undefined,
/* skipUnknownArgCheck */ false,
makeInferenceContext(getNoneType())
);
}
} else if (classType.details.effectiveMetaclass && isClass(classType.details.effectiveMetaclass)) {
if (classType.details.effectiveMetaclass && isClass(classType.details.effectiveMetaclass)) {
// See if the metaclass has a `__new__` method that accepts keyword parameters.
const newMethodMember = lookUpClassMember(
newMethodMember = lookUpClassMember(
classType.details.effectiveMetaclass,
'__new__',
MemberAccessFlags.SkipTypeBaseClass
);
}

if (newMethodMember) {
const newMethodType = getTypeOfMember(newMethodMember);
if (isFunction(newMethodType)) {
const paramListDetails = getParameterListDetails(newMethodType);

if (paramListDetails.firstKeywordOnlyIndex !== undefined) {
// Build a map of the keyword-only parameters.
const paramMap = new Map<string, number>();
for (let i = paramListDetails.firstKeywordOnlyIndex; i < paramListDetails.params.length; i++) {
const paramInfo = paramListDetails.params[i];
if (paramInfo.param.category === ParameterCategory.Simple && paramInfo.param.name) {
paramMap.set(paramInfo.param.name, i);
}
if (newMethodMember) {
const newMethodType = getTypeOfMember(newMethodMember);
if (isFunction(newMethodType)) {
const paramListDetails = getParameterListDetails(newMethodType);

if (paramListDetails.firstKeywordOnlyIndex !== undefined) {
// Build a map of the keyword-only parameters.
const paramMap = new Map<string, number>();
for (let i = paramListDetails.firstKeywordOnlyIndex; i < paramListDetails.params.length; i++) {
const paramInfo = paramListDetails.params[i];
if (paramInfo.param.category === ParameterCategory.Simple && paramInfo.param.name) {
paramMap.set(paramInfo.param.name, i);
}
}

argList.forEach((arg) => {
const signatureTracker = new UniqueSignatureTracker();

if (arg.argumentCategory === ArgumentCategory.Simple && arg.name) {
const paramIndex = paramMap.get(arg.name.value) ?? paramListDetails.kwargsIndex;

if (paramIndex !== undefined) {
const paramInfo = paramListDetails.params[paramIndex];
const argParam: ValidateArgTypeParams = {
paramCategory: paramInfo.param.category,
paramType: FunctionType.getEffectiveParameterType(
newMethodType,
paramInfo.index
),
requiresTypeVarMatching: false,
argument: arg,
errorNode: arg.valueExpression ?? errorNode,
};

validateArgType(
argParam,
new TypeVarContext(),
signatureTracker,
{ type: newMethodType },
{ skipUnknownArgCheck: true, skipOverloadArg: true }
);
paramMap.delete(arg.name.value);
} else {
addDiagnostic(
AnalyzerNodeInfo.getFileInfo(node).diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
Localizer.Diagnostic.paramNameMissing().format({ name: arg.name.value }),
arg.name ?? errorNode
);
}
}
});
argList.forEach((arg) => {
const signatureTracker = new UniqueSignatureTracker();

if (arg.argumentCategory === ArgumentCategory.Simple && arg.name) {
const paramIndex = paramMap.get(arg.name.value) ?? paramListDetails.kwargsIndex;

// See if we have any remaining unmatched parameters without
// default values.
const unassignedParams: string[] = [];
paramMap.forEach((index, paramName) => {
const paramInfo = paramListDetails.params[index];
if (!paramInfo.param.hasDefault) {
unassignedParams.push(paramName);
if (paramIndex !== undefined) {
const paramInfo = paramListDetails.params[paramIndex];
const argParam: ValidateArgTypeParams = {
paramCategory: paramInfo.param.category,
paramType: FunctionType.getEffectiveParameterType(newMethodType, paramInfo.index),
requiresTypeVarMatching: false,
argument: arg,
errorNode: arg.valueExpression ?? errorNode,
};

validateArgType(
argParam,
new TypeVarContext(),
signatureTracker,
{ type: newMethodType },
{ skipUnknownArgCheck: true, skipOverloadArg: true }
);
paramMap.delete(arg.name.value);
} else {
addDiagnostic(
AnalyzerNodeInfo.getFileInfo(node).diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
Localizer.Diagnostic.paramNameMissing().format({ name: arg.name.value }),
arg.name ?? errorNode
);
}
});
}
});

if (unassignedParams.length > 0) {
const missingParamNames = unassignedParams.map((p) => `"${p}"`).join(', ');
addDiagnostic(
AnalyzerNodeInfo.getFileInfo(errorNode).diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
unassignedParams.length === 1
? Localizer.Diagnostic.argMissingForParam().format({ name: missingParamNames })
: Localizer.Diagnostic.argMissingForParams().format({ names: missingParamNames }),
errorNode
// See if we have any remaining unmatched parameters without
// default values.
const unassignedParams: string[] = [];
paramMap.forEach((index, paramName) => {
const paramInfo = paramListDetails.params[index];
if (!paramInfo.param.hasDefault) {
unassignedParams.push(paramName);
}
});

if (unassignedParams.length > 0) {
const missingParamNames = unassignedParams.map((p) => `"${p}"`).join(', ');
addDiagnostic(
AnalyzerNodeInfo.getFileInfo(errorNode).diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
unassignedParams.length === 1
? Localizer.Diagnostic.argMissingForParam().format({ name: missingParamNames })
: Localizer.Diagnostic.argMissingForParams().format({ names: missingParamNames }),
errorNode
);
}
}
}
} else {
// If there was no custom metaclass __new__ method, see if there is an __init_subclass__
// method present somewhere in the class hierarchy.
const initSubclassMethodInfo = getTypeOfBoundMember(
errorNode,
classType,
'__init_subclass__',
/* usage */ undefined,
/* diag */ undefined,
MemberAccessFlags.SkipClassMembers |
MemberAccessFlags.SkipOriginalClass |
MemberAccessFlags.SkipAttributeAccessOverride
);

if (initSubclassMethodInfo) {
const initSubclassMethodType = initSubclassMethodInfo.type;

if (initSubclassMethodType && initSubclassMethodInfo.classType) {
const callResult = validateCallArguments(
errorNode,
argList,
{ type: initSubclassMethodType },
/* typeVarContext */ undefined,
/* skipUnknownArgCheck */ false,
makeInferenceContext(getNoneType())
);

if (callResult.argumentErrors) {
const diag = addDiagnostic(
AnalyzerNodeInfo.getFileInfo(errorNode).diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
Localizer.Diagnostic.initSubclassCallFailed(),
node.name
);

const initSubclassFunction = isOverloadedFunction(initSubclassMethodType)
? OverloadedFunctionType.getOverloads(initSubclassMethodType)[0]
: initSubclassMethodType;
const initSubclassDecl = isFunction(initSubclassFunction)
? initSubclassFunction.details.declaration
: undefined;

if (diag && initSubclassDecl) {
diag.addRelatedInfo(
Localizer.DiagnosticAddendum.initSubclassLocation().format({
name: printType(convertToInstance(initSubclassMethodInfo.classType)),
}),
initSubclassDecl.path,
initSubclassDecl.range
);
}
}
Expand Down
3 changes: 3 additions & 0 deletions packages/pyright-internal/src/localization/localize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ export namespace Localizer {
export const inconsistentTabs = () => getRawString('Diagnostic.inconsistentTabs');
export const initMustReturnNone = () => getRawString('Diagnostic.initMustReturnNone');
export const initSubclassClsParam = () => getRawString('Diagnostic.initSubclassClsParam');
export const initSubclassCallFailed = () => getRawString('Diagnostic.initSubclassCallFailed');
export const instanceMethodSelfParam = () => getRawString('Diagnostic.instanceMethodSelfParam');
export const instanceVarOverridesClassVar = () =>
new ParameterizedString<{ name: string; className: string }>(
Expand Down Expand Up @@ -1199,6 +1200,8 @@ export namespace Localizer {
new ParameterizedString<{ type: string }>(getRawString('DiagnosticAddendum.initMethodLocation'));
export const initMethodSignature = () =>
new ParameterizedString<{ type: string }>(getRawString('DiagnosticAddendum.initMethodSignature'));
export const initSubclassLocation = () =>
new ParameterizedString<{ name: string }>(getRawString('DiagnosticAddendum.initSubclassLocation'));
export const invariantSuggestionDict = () => getRawString('DiagnosticAddendum.invariantSuggestionDict');
export const invariantSuggestionList = () => getRawString('DiagnosticAddendum.invariantSuggestionList');
export const functionTooManyParams = () =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@
"initMustReturnNone": "Return type of \"__init__\" must be None",
"inconsistentTabs": "Inconsistent use of tabs and spaces in indentation",
"initSubclassClsParam": "__init_subclass__ override should take a \"cls\" parameter",
"initSubclassCallFailed": "Incorrect keyword arguments for __init_subclass__ method",
"instanceMethodSelfParam": "Instance methods should take a \"self\" parameter",
"instanceVarOverridesClassVar": "Instance variable \"{name}\" overrides class variable of same name in class \"{className}\"",
"instantiateAbstract": "Cannot instantiate abstract class \"{type}\"",
Expand Down Expand Up @@ -613,6 +614,7 @@
"initMethodLocation": "The __init__ method is defined in class \"{type}\"",
"incompatibleDeleter": "Property deleter method is incompatible",
"initMethodSignature": "Signature of __init__ is \"{type}\"",
"initSubclassLocation": "The __init_subclass__ method is defined in class \"{name}\"",
"invariantSuggestionDict": "Consider switching from \"dict\" to \"Mapping\" which is covariant in the value type",
"invariantSuggestionList": "Consider switching from \"list\" to \"Sequence\" which is covariant",
"kwargsParamMissing": "Parameter \"**{paramName}\" has no corresponding parameter",
Expand Down
10 changes: 9 additions & 1 deletion packages/pyright-internal/src/tests/samples/classes1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# handle various class definition cases.


from typing import Any


class A:
...

Expand All @@ -17,7 +20,12 @@ class D(app.C):
...


class E:
class EMeta(type):
def __new__(mcls, *args: Any, **kwargs: Any):
...


class E(metaclass=EMeta):
pass


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ class Customer1:
name: str


@create_model
@create_model(frozen=True)
class Customer2:
id: int
name: str


@create_model
class Customer2Subclass(Customer2, frozen=True):
@create_model(frozen=True)
class Customer2Subclass(Customer2):
salary: float


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def model_field(
class ModelMeta(type):
not_a_field: str


class ModelBase(metaclass=ModelMeta):
def __init_subclass__(
cls,
*,
Expand All @@ -35,10 +37,6 @@ def __init_subclass__(
...


class ModelBase(metaclass=ModelMeta):
...


class Customer1(ModelBase, frozen=True):
id: int = model_field()
name: str = model_field()
Expand Down
12 changes: 9 additions & 3 deletions packages/pyright-internal/src/tests/samples/initsubclass1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# PEP 487.

from datetime import datetime
from typing import Any, Optional, Type
from typing import Any, Optional, Type, TypedDict


class ClassA:
Expand All @@ -13,13 +13,13 @@ def __init_subclass__(
super().__init_subclass__()


# This should generate an error because param1 is
# This should generate two errors because param1 is
# the wrong type.
class ClassB(ClassA, param1=0, param2=4):
pass


# This should generate an error because param2 is missing.
# This should generate two errors because param2 is missing.
class ClassC(ClassA, param1="0", param3=datetime.now()):
pass

Expand Down Expand Up @@ -56,3 +56,9 @@ class ClassG:

class ClassH(ClassG):
pass


# This should generate two errors because "a" is not present
# in the object.__init_subclass__ method.
class ClassI(a=3):
a: int
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ class B(A, param_a=123):
pass


# This should generate an error because param_a is missing
# This should generate two errors because param_a is missing.
class C(B):
pass
4 changes: 2 additions & 2 deletions packages/pyright-internal/src/tests/typeEvaluator3.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1370,13 +1370,13 @@ test('EmptyContainers1', () => {
test('InitSubclass1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['initsubclass1.py']);

TestUtils.validateResults(analysisResults, 2);
TestUtils.validateResults(analysisResults, 6);
});

test('InitSubclass2', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['initsubclass2.py']);

TestUtils.validateResults(analysisResults, 1);
TestUtils.validateResults(analysisResults, 2);
});

test('None1', () => {
Expand Down

0 comments on commit de7625f

Please sign in to comment.