diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 55677d0a..afd53921 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -38,7 +38,7 @@ ], "postCreateCommand": "make devenv", "features": { - "ghcr.io/devcontainers/features/docker-in-docker:2": {}, + "ghcr.io/devcontainers/features/docker-in-docker:2.11.0": {}, "ghcr.io/devcontainers/features/java:1": { "version": "11" } diff --git a/fugue/dataframe/function_wrapper.py b/fugue/dataframe/function_wrapper.py index 342d103c..f6571f22 100644 --- a/fugue/dataframe/function_wrapper.py +++ b/fugue/dataframe/function_wrapper.py @@ -20,6 +20,7 @@ PositionalParam, function_wrapper, ) +from triad.utils.convert import compare_annotations from triad.utils.iter import EmptyAwareIterable, make_empty_aware from ..constants import FUGUE_ENTRYPOINT @@ -37,6 +38,14 @@ from .pandas_dataframe import PandasDataFrame +def _compare_iter(tp: Any) -> Any: + return lambda x: compare_annotations( + x, Iterable[tp] # type:ignore + ) or compare_annotations( + x, Iterator[tp] # type:ignore + ) + + @function_wrapper(FUGUE_ENTRYPOINT) class DataFrameFunctionWrapper(FunctionWrapper): @property @@ -228,10 +237,7 @@ def count(self, df: List[List[Any]]) -> int: return len(df) -@fugue_annotated_param( - Iterable[List[Any]], - matcher=lambda x: x == Iterable[List[Any]] or x == Iterator[List[Any]], -) +@fugue_annotated_param(Iterable[List[Any]], matcher=_compare_iter(List[Any])) class _IterableListParam(_LocalNoSchemaDataFrameParam): @no_type_check def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[List[Any]]: @@ -288,10 +294,7 @@ def count(self, df: List[Dict[str, Any]]) -> int: return len(df) -@fugue_annotated_param( - Iterable[Dict[str, Any]], - matcher=lambda x: x == Iterable[Dict[str, Any]] or x == Iterator[Dict[str, Any]], -) +@fugue_annotated_param(Iterable[Dict[str, Any]], matcher=_compare_iter(Dict[str, Any])) class _IterableDictParam(_LocalNoSchemaDataFrameParam): @no_type_check def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[Dict[str, Any]]: @@ -360,10 +363,7 @@ def format_hint(self) -> Optional[str]: return "pandas" -@fugue_annotated_param( - Iterable[pd.DataFrame], - matcher=lambda x: x == Iterable[pd.DataFrame] or x == Iterator[pd.DataFrame], -) +@fugue_annotated_param(Iterable[pd.DataFrame], matcher=_compare_iter(pd.DataFrame)) class _IterablePandasParam(LocalDataFrameParam): @no_type_check def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pd.DataFrame]: @@ -419,10 +419,7 @@ def format_hint(self) -> Optional[str]: return "pyarrow" -@fugue_annotated_param( - Iterable[pa.Table], - matcher=lambda x: x == Iterable[pa.Table] or x == Iterator[pa.Table], -) +@fugue_annotated_param(Iterable[pa.Table], matcher=_compare_iter(pa.Table)) class _IterableArrowParam(LocalDataFrameParam): @no_type_check def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pa.Table]: diff --git a/fugue_spark/_utils/misc.py b/fugue_spark/_utils/misc.py index 78b28041..37ea7e79 100644 --- a/fugue_spark/_utils/misc.py +++ b/fugue_spark/_utils/misc.py @@ -3,7 +3,7 @@ try: from pyspark.sql.connect.session import SparkSession as SparkConnectSession from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame -except ImportError: # pragma: no cover +except Exception: # pragma: no cover SparkConnectSession = None SparkConnectDataFrame = None import pyspark.sql as ps diff --git a/fugue_version/__init__.py b/fugue_version/__init__.py index 3e2f46a3..d69d16e9 100644 --- a/fugue_version/__init__.py +++ b/fugue_version/__init__.py @@ -1 +1 @@ -__version__ = "0.9.0" +__version__ = "0.9.1" diff --git a/tests/fugue/dataframe/test_function_wrapper.py b/tests/fugue/dataframe/test_function_wrapper.py index fa634920..4cced3e9 100644 --- a/tests/fugue/dataframe/test_function_wrapper.py +++ b/tests/fugue/dataframe/test_function_wrapper.py @@ -26,10 +26,14 @@ ) from fugue.dataframe.utils import _df_eq as df_eq from fugue.dev import DataFrameFunctionWrapper +import sys def test_function_wrapper(): - for f in [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]: + fs = [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36] + if sys.version_info >= (3, 9): + fs.append(f33) + for f in fs: df = ArrayDataFrame([[0]], "a:int") w = DataFrameFunctionWrapper(f, "^[ldsp][ldsp]$", "[ldspq]") res = w.run([df], dict(a=df), ignore_unknown=False, output_schema="a:int") @@ -372,6 +376,14 @@ def f32( return ArrayDataFrame(arr, "a:int").as_dict_iterable() +def f33( + e: list[dict[str, Any]], a: Iterable[dict[str, Any]] +) -> EmptyAwareIterable[Dict[str, Any]]: + e += list(a) + arr = [[x["a"]] for x in e] + return ArrayDataFrame(arr, "a:int").as_dict_iterable() + + def f35(e: pd.DataFrame, a: LocalDataFrame) -> Iterable[pd.DataFrame]: e = PandasDataFrame(e, "a:int").as_pandas() a = ArrayDataFrame(a, "a:int").as_pandas()