Skip to content

Commit

Permalink
Add more supported type annotations, fix spark connect issue
Browse files Browse the repository at this point in the history
  • Loading branch information
goodwanghan authored Jun 12, 2024
1 parent 48b7ab6 commit d86c3ac
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
],
"postCreateCommand": "make devenv",
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
"ghcr.io/devcontainers/features/docker-in-docker:2.11.0": {},
"ghcr.io/devcontainers/features/java:1": {
"version": "11"
}
Expand Down
29 changes: 13 additions & 16 deletions fugue/dataframe/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
PositionalParam,
function_wrapper,
)
from triad.utils.convert import compare_annotations
from triad.utils.iter import EmptyAwareIterable, make_empty_aware

from ..constants import FUGUE_ENTRYPOINT
Expand All @@ -37,6 +38,14 @@
from .pandas_dataframe import PandasDataFrame


def _compare_iter(tp: Any) -> Any:
    """Build a matcher for iterable-style annotations over ``tp``.

    The returned callable takes one annotation ``x`` and checks it with
    ``compare_annotations`` against ``Iterable[tp]`` first and, if that
    result is falsy, against ``Iterator[tp]``.

    :param tp: the element type used to parameterize ``Iterable`` and
        ``Iterator``
    :return: a one-argument callable yielding the first truthy comparison
        result, or the result of the ``Iterator[tp]`` comparison otherwise
    """

    def _matcher(x: Any) -> Any:
        # The generic aliases are built on every call, exactly as in a
        # lambda body, so nothing is evaluated when the matcher is created.
        as_iterable = compare_annotations(x, Iterable[tp])  # type:ignore
        if as_iterable:
            return as_iterable
        return compare_annotations(x, Iterator[tp])  # type:ignore

    return _matcher


@function_wrapper(FUGUE_ENTRYPOINT)
class DataFrameFunctionWrapper(FunctionWrapper):
@property
Expand Down Expand Up @@ -228,10 +237,7 @@ def count(self, df: List[List[Any]]) -> int:
return len(df)


@fugue_annotated_param(
Iterable[List[Any]],
matcher=lambda x: x == Iterable[List[Any]] or x == Iterator[List[Any]],
)
@fugue_annotated_param(Iterable[List[Any]], matcher=_compare_iter(List[Any]))
class _IterableListParam(_LocalNoSchemaDataFrameParam):
@no_type_check
def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[List[Any]]:
Expand Down Expand Up @@ -288,10 +294,7 @@ def count(self, df: List[Dict[str, Any]]) -> int:
return len(df)


@fugue_annotated_param(
Iterable[Dict[str, Any]],
matcher=lambda x: x == Iterable[Dict[str, Any]] or x == Iterator[Dict[str, Any]],
)
@fugue_annotated_param(Iterable[Dict[str, Any]], matcher=_compare_iter(Dict[str, Any]))
class _IterableDictParam(_LocalNoSchemaDataFrameParam):
@no_type_check
def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[Dict[str, Any]]:
Expand Down Expand Up @@ -360,10 +363,7 @@ def format_hint(self) -> Optional[str]:
return "pandas"


@fugue_annotated_param(
Iterable[pd.DataFrame],
matcher=lambda x: x == Iterable[pd.DataFrame] or x == Iterator[pd.DataFrame],
)
@fugue_annotated_param(Iterable[pd.DataFrame], matcher=_compare_iter(pd.DataFrame))
class _IterablePandasParam(LocalDataFrameParam):
@no_type_check
def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pd.DataFrame]:
Expand Down Expand Up @@ -419,10 +419,7 @@ def format_hint(self) -> Optional[str]:
return "pyarrow"


@fugue_annotated_param(
Iterable[pa.Table],
matcher=lambda x: x == Iterable[pa.Table] or x == Iterator[pa.Table],
)
@fugue_annotated_param(Iterable[pa.Table], matcher=_compare_iter(pa.Table))
class _IterableArrowParam(LocalDataFrameParam):
@no_type_check
def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pa.Table]:
Expand Down
2 changes: 1 addition & 1 deletion fugue_spark/_utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
try:
from pyspark.sql.connect.session import SparkSession as SparkConnectSession
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
except ImportError: # pragma: no cover
except Exception: # pragma: no cover
SparkConnectSession = None
SparkConnectDataFrame = None
import pyspark.sql as ps
Expand Down
2 changes: 1 addition & 1 deletion fugue_version/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.0"
__version__ = "0.9.1"
14 changes: 13 additions & 1 deletion tests/fugue/dataframe/test_function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@
)
from fugue.dataframe.utils import _df_eq as df_eq
from fugue.dev import DataFrameFunctionWrapper
import sys


def test_function_wrapper():
for f in [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]:
fs = [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]
if sys.version_info >= (3, 9):
fs.append(f33)
for f in fs:
df = ArrayDataFrame([[0]], "a:int")
w = DataFrameFunctionWrapper(f, "^[ldsp][ldsp]$", "[ldspq]")
res = w.run([df], dict(a=df), ignore_unknown=False, output_schema="a:int")
Expand Down Expand Up @@ -372,6 +376,14 @@ def f32(
return ArrayDataFrame(arr, "a:int").as_dict_iterable()


def f33(
    e: list[dict[str, Any]], a: Iterable[dict[str, Any]]
) -> EmptyAwareIterable[Dict[str, Any]]:
    """Test fixture annotated with builtin generics (``list[dict[...]]``).

    Appends every record from ``a`` onto ``e`` in place, then wraps the
    ``"a"`` value of each record in a one-column ``ArrayDataFrame`` and
    returns it as a dict iterable.
    """
    # In-place extension: the list object passed in as ``e`` is mutated.
    e += list(a)
    rows = []
    for record in e:
        rows.append([record["a"]])
    frame = ArrayDataFrame(rows, "a:int")
    return frame.as_dict_iterable()


def f35(e: pd.DataFrame, a: LocalDataFrame) -> Iterable[pd.DataFrame]:
e = PandasDataFrame(e, "a:int").as_pandas()
a = ArrayDataFrame(a, "a:int").as_pandas()
Expand Down

0 comments on commit d86c3ac

Please sign in to comment.