Skip to content

Commit

Permalink
refactor match code
Browse files Browse the repository at this point in the history
  • Loading branch information
adhami3310 committed Feb 13, 2025
1 parent d1ff6d5 commit c743139
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions reflex/components/core/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Unpack

from reflex.components.base import Fragment
from reflex.components.component import BaseComponent, Component
from reflex.components.component import BaseComponent
from reflex.utils import types
from reflex.utils.exceptions import MatchTypeError
from reflex.vars.base import VAR_TYPE, Var
Expand Down Expand Up @@ -36,11 +36,14 @@ def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
)


def _validate_return_types(match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]) -> None:
def _validate_return_types(*return_values: Any) -> bool:
"""Validate that match cases have the same return types.
Args:
match_cases: The match cases.
return_values: The return values of the match cases.
Returns:
True if all cases have the same return types.
Raises:
MatchTypeError: If the return types of cases are different.
Expand All @@ -54,22 +57,20 @@ def is_component_or_component_var(obj: Any) -> bool:
)
)

def type_of_return_type(obj: Any) -> Any:
def type_of_return_value(obj: Any) -> Any:
if isinstance(obj, Var):
return obj._var_type
return type(obj)

return_types = [case[-1] for case in match_cases]
is_return_type_component = [
is_component_or_component_var(return_type) for return_type in return_values
]

if any(
is_component_or_component_var(return_type) for return_type in return_types
) and not all(
is_component_or_component_var(return_type) for return_type in return_types
):
if any(is_return_type_component) and not all(is_return_type_component):
non_component_return_types = [
(type_of_return_type(return_type), i)
for i, return_type in enumerate(return_types)
if not is_component_or_component_var(return_type)
(type_of_return_value(return_value), i)
for i, return_value in enumerate(return_values)
if not is_return_type_component[i]
]
raise MatchTypeError(
"Match cases should have the same return types. "
Expand All @@ -82,6 +83,8 @@ def type_of_return_type(obj: Any) -> Any:
)
)

return all(is_return_type_component)


def _create_match_var(
match_cond_var: Var,
Expand Down Expand Up @@ -119,7 +122,7 @@ def match(
Raises:
ValueError: If the default case is not the last case or the tuple elements are less than 2.
"""
default = None
default = types.Unset()

if len([case for case in cases if not isinstance(case, tuple)]) > 1:
raise ValueError("rx.match can only have one default case.")
Expand All @@ -136,22 +139,17 @@ def match(

_process_match_cases(actual_cases)

_validate_return_types(actual_cases)
is_component_match = _validate_return_types(
*[case[-1] for case in actual_cases],
*([default] if not isinstance(default, types.Unset) else []),
)

if default is None and any(
not (
isinstance((return_type := case[-1]), Component)
or (
isinstance(return_type, Var)
and types.typehint_issubclass(return_type._var_type, Component)
)
)
for case in actual_cases
):
if isinstance(default, types.Unset) and not is_component_match:
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
)
elif default is None:

if isinstance(default, types.Unset):
default = Fragment.create()

default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
Expand Down

0 comments on commit c743139

Please sign in to comment.