From 81276ef9a1e80e5a17b0c98b8d1126b5864239cf Mon Sep 17 00:00:00 2001
From: Han Wang
Date: Thu, 17 Aug 2023 00:05:46 -0700
Subject: [PATCH] Fix dask bugs (#507)

* Fix dask bugs

* update
---
 fugue_dask/_utils.py                       |  7 +++-
 fugue_ray/_utils/io.py                     |  2 +-
 tests/fugue_dask/test_execution_engine.py  | 45 ++++++++++++++++++++++-
 3 files changed, 49 insertions(+), 5 deletions(-)

diff --git a/fugue_dask/_utils.py b/fugue_dask/_utils.py
index 09c79c02..466faa58 100644
--- a/fugue_dask/_utils.py
+++ b/fugue_dask/_utils.py
@@ -41,6 +41,7 @@ def hash_repartition(df: dd.DataFrame, num: int, cols: List[Any]) -> dd.DataFram
         return df
     if num == 1:
         return df.repartition(1)
+    df = df.reset_index(drop=True).clear_divisions()
     idf, ct = _add_hash_index(df, num, cols)
     return _postprocess(idf, ct, num)
 
@@ -63,9 +64,10 @@ def even_repartition(df: dd.DataFrame, num: int, cols: List[Any]) -> dd.DataFram
     """
     if num == 1:
         return df.repartition(1)
+    if len(cols) == 0 and num <= 0:
+        return df
+    df = df.reset_index(drop=True).clear_divisions()
     if len(cols) == 0:
-        if num <= 0:
-            return df
         idf, ct = _add_continuous_index(df)
     else:
         idf, ct = _add_group_index(df, cols, shuffle=False)
@@ -97,6 +99,7 @@ def rand_repartition(
         return df
     if num == 1:
         return df.repartition(1)
+    df = df.reset_index(drop=True).clear_divisions()
     if len(cols) == 0:
         idf, ct = _add_random_index(df, num=num, seed=seed)
     else:
diff --git a/fugue_ray/_utils/io.py b/fugue_ray/_utils/io.py
index deb66b4e..6198b81c 100644
--- a/fugue_ray/_utils/io.py
+++ b/fugue_ray/_utils/io.py
@@ -100,7 +100,7 @@ def _save_csv(
         if "header" in kw:
             kw["include_header"] = kw.pop("header")
 
-        def _fn() -> Dict[str, Any]:
+        def _fn() -> Dict[str, Any]:  # pragma: no cover
            return dict(write_options=pacsv.WriteOptions(**kw))
 
        df.native.write_csv(
diff --git a/tests/fugue_dask/test_execution_engine.py b/tests/fugue_dask/test_execution_engine.py
index 2f6889d2..5e0c11d3 100644
--- a/tests/fugue_dask/test_execution_engine.py
+++ b/tests/fugue_dask/test_execution_engine.py
@@ -2,8 +2,8 @@
 from threading import RLock
 from typing import Any, List, Optional
 
-import dask
 import dask.dataframe as dd
+import numpy as np
 import pandas as pd
 import pytest
 from dask.distributed import Client
@@ -25,7 +25,6 @@
 from fugue_test.builtin_suite import BuiltInTests
 from fugue_test.execution_suite import ExecutionEngineTests
 
-
 _CONF = {
     "fugue.rpc.server": "fugue.rpc.flask.FlaskRPCServer",
     "fugue.rpc.flask_server.host": "127.0.0.1",
@@ -321,6 +320,48 @@ def tr(df: List[List[Any]], add: Optional[callable]) -> List[List[Any]]:
     assert 5 == cb.n
 
 
+def test_multiple_transforms(fugue_dask_client):
+    def t1(df: pd.DataFrame) -> pd.DataFrame:
+        return pd.concat([df, df])
+
+    def t2(df: pd.DataFrame) -> pd.DataFrame:
+        return (
+            df.groupby(["a", "b"], as_index=False, dropna=False)
+            .apply(lambda x: x.head(1))
+            .reset_index(drop=True)
+        )
+
+    def compute(df: pd.DataFrame, engine) -> pd.DataFrame:
+        with fa.engine_context(engine):
+            ddf = fa.as_fugue_df(df)
+            ddf1 = fa.transform(ddf, t1, schema="*", partition=dict(algo="hash"))
+            ddf2 = fa.transform(
+                ddf1,
+                t2,
+                schema="*",
+                partition=dict(by=["a", "b"], presort="c", algo="coarse", num=2),
+            )
+            return (
+                ddf2.as_pandas()
+                .astype("float64")
+                .fillna(float("nan"))
+                .sort_values(["a", "b"])
+            )
+
+    np.random.seed(0)
+    df = pd.DataFrame(
+        dict(
+            a=np.random.randint(1, 5, 1000),
+            b=np.random.choice([1, 2, 3, None], 1000),
+            c=np.random.rand(1000),
+        )
+    )
+
+    actual = compute(df, fugue_dask_client)
+    expected = compute(df, None)
+    assert np.allclose(actual, expected, equal_nan=True)
+
+
 @transformer("ct:long")
 def count_partition(df: List[List[Any]]) -> List[List[Any]]:
     return [[len(df)]]
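
Note (illustrative only, not part of the patch): each repartition helper above now calls reset_index(drop=True).clear_divisions() before building its own partitioning index. A minimal sketch of what that normalization does, assuming only pandas and dask are installed; the variable names are hypothetical:

    import dask.dataframe as dd
    import pandas as pd

    pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6], "b": list("xyzxyz")})
    ddf = dd.from_pandas(pdf, npartitions=3)

    # Operations such as concat can leave a duplicated index and stale or
    # unknown divisions on the result.
    doubled = dd.concat([ddf, ddf])

    # Resetting the index (per partition in dask) and clearing divisions
    # gives the helpers a clean starting point before they attach their own
    # hash / continuous / random index, mirroring the change above.
    clean = doubled.reset_index(drop=True).clear_divisions()
    print(clean.known_divisions)                         # False
    print(clean.map_partitions(len).compute().tolist())  # rows per partition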