Skip to content

Commit

Permalink
Added support for bidirectional type inference for assignment stateme…
Browse files Browse the repository at this point in the history
…nts that are assigning to an index expression that is subscripted by a slice. This addresses #9564.
  • Loading branch information
erictraut committed Jan 14, 2025
1 parent 575c639 commit 888ab68
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
41 changes: 34 additions & 7 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2761,13 +2761,7 @@ export function createTypeEvaluator(
);

if (baseType && isClassInstance(baseType)) {
const setItemType = getBoundMagicMethod(baseType, '__setitem__');
if (setItemType && isFunction(setItemType) && setItemType.shared.parameters.length >= 2) {
const paramType = FunctionType.getParamType(setItemType, 1);
if (!isAnyOrUnknown(paramType)) {
return paramType;
}
} else if (ClassType.isTypedDictClass(baseType)) {
if (ClassType.isTypedDictClass(baseType)) {
const typeFromTypedDict = getTypeOfIndexedTypedDict(
evaluatorInterface,
expression,
Expand All @@ -2778,6 +2772,39 @@ export function createTypeEvaluator(
return typeFromTypedDict.type;
}
}

let setItemType = getBoundMagicMethod(baseType, '__setitem__');
if (!setItemType) {
break;
}

if (isOverloaded(setItemType)) {
// Determine whether we need to use the slice overload.
const expectsSlice =
expression.d.items.length === 1 &&
expression.d.items[0].d.valueExpr.nodeType === ParseNodeType.Slice;
const overloads = OverloadedType.getOverloads(setItemType);
setItemType = overloads.find((overload) => {
if (overload.shared.parameters.length < 2) {
return false;
}

const keyType = FunctionType.getParamType(overload, 0);
const isSlice = isClassInstance(keyType) && ClassType.isBuiltIn(keyType, 'slice');
return expectsSlice === isSlice;
});

if (!setItemType) {
break;
}
}

if (isFunction(setItemType) && setItemType.shared.parameters.length >= 2) {
const paramType = FunctionType.getParamType(setItemType, 1);
if (!isAnyOrUnknown(paramType)) {
return paramType;
}
}
}
break;
}
Expand Down
7 changes: 6 additions & 1 deletion packages/pyright-internal/src/tests/samples/index1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# when used with the __getitem__ and __setitem__ method.


from typing import Generic, Self, Type, TypeVar, Any
from typing import Generic, Literal, Self, Type, TypeVar, Any


class MyInt:
Expand Down Expand Up @@ -121,3 +121,8 @@ class ClassI:


reveal_type(ClassI()[0], expected_text="ClassH")


def func4(l: list[Literal["a", "b"]]):
l[0] = "a"
l[0:0] = ["a", "b"]

0 comments on commit 888ab68

Please sign in to comment.