Skip to content

Commit

Permalink
Fix enum truthiness for StrEnum (#18379)
Browse files Browse the repository at this point in the history
  • Loading branch information
hauntsaninja authored Dec 30, 2024
1 parent 60bff6c commit e139a0d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 18 deletions.
12 changes: 5 additions & 7 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,19 +648,14 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[
return items


def _get_type_method_ret_type(t: Type, *, name: str) -> Type | None:
t = get_proper_type(t)

def _get_type_method_ret_type(t: ProperType, *, name: str) -> Type | None:
# For Enum literals the ret_type can change based on the Enum
# we need to check the type of the enum rather than the literal
if isinstance(t, LiteralType) and t.is_enum_literal():
t = t.fallback

if isinstance(t, Instance):
sym = t.type.get(name)
# Fallback to the metaclass for the lookup when necessary
if not sym and (m := t.type.metaclass_type):
sym = m.type.get(name)
if sym:
sym_type = get_proper_type(sym.type)
if isinstance(sym_type, CallableType):
Expand Down Expand Up @@ -733,7 +728,10 @@ def false_only(t: Type) -> ProperType:
if ret_type:
if not ret_type.can_be_false:
return UninhabitedType(line=t.line)
elif isinstance(t, Instance) and t.type.is_final:
elif isinstance(t, Instance):
if t.type.is_final or t.type.is_enum:
return UninhabitedType(line=t.line)
elif isinstance(t, LiteralType) and t.is_enum_literal():
return UninhabitedType(line=t.line)

new_t = copy_type(t)
Expand Down
93 changes: 83 additions & 10 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -181,27 +181,100 @@ def infer_truth(truth: Truth) -> None:
[case testEnumTruthyness]
# mypy: warn-unreachable
import enum
from typing_extensions import Literal

class E(enum.Enum):
x = 0
if not E.x:
"noop"
zero = 0
one = 1

def print(s: str) -> None: ...

if E.zero:
print("zero is true")
if not E.zero:
print("zero is false") # E: Statement is unreachable

if E.one:
print("one is true")
if not E.one:
print("one is false") # E: Statement is unreachable

def main(zero: Literal[E.zero], one: Literal[E.one]) -> None:
if zero:
print("zero is true")
if not zero:
print("zero is false") # E: Statement is unreachable
if one:
print("one is true")
if not one:
print("one is false") # E: Statement is unreachable
[builtins fixtures/tuple.pyi]
[out]
main:6: error: Statement is unreachable

[case testEnumTruthynessCustomDunderBool]
# mypy: warn-unreachable
import enum
from typing_extensions import Literal

class E(enum.Enum):
x = 0
zero = 0
one = 1
def __bool__(self) -> Literal[False]:
return False
if E.x:
"noop"

def print(s: str) -> None: ...

if E.zero:
print("zero is true") # E: Statement is unreachable
if not E.zero:
print("zero is false")

if E.one:
print("one is true") # E: Statement is unreachable
if not E.one:
print("one is false")

def main(zero: Literal[E.zero], one: Literal[E.one]) -> None:
if zero:
print("zero is true") # E: Statement is unreachable
if not zero:
print("zero is false")
if one:
print("one is true") # E: Statement is unreachable
if not one:
print("one is false")
[builtins fixtures/enum.pyi]

[case testEnumTruthynessStrEnum]
# mypy: warn-unreachable
import enum
from typing_extensions import Literal

class E(enum.StrEnum):
empty = ""
not_empty = "asdf"

def print(s: str) -> None: ...

if E.empty:
print("empty is true")
if not E.empty:
print("empty is false")

if E.not_empty:
print("not_empty is true")
if not E.not_empty:
print("not_empty is false")

def main(empty: Literal[E.empty], not_empty: Literal[E.not_empty]) -> None:
if empty:
print("empty is true")
if not empty:
print("empty is false")
if not_empty:
print("not_empty is true")
if not not_empty:
print("not_empty is false")
[builtins fixtures/enum.pyi]
[out]
main:9: error: Statement is unreachable

[case testEnumUnique]
import enum
Expand Down
4 changes: 3 additions & 1 deletion test-data/unit/fixtures/enum.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class tuple(Generic[T]):
def __getitem__(self, x: int) -> T: pass

class int: pass
class str: pass
class str:
def __len__(self) -> int: pass

class dict: pass
class ellipsis: pass

0 comments on commit e139a0d

Please sign in to comment.