Commit e6c3591

update

goodwanghan committed Feb 14, 2024
1 parent 31d01ed commit e6c3591
Showing 8 changed files with 110 additions and 76 deletions.
fugue_snowflake/_utils.py (47 additions, 25 deletions)

@@ -11,13 +11,11 @@
 from snowflake.connector.result_batch import ResultBatch
 from triad import Schema
 from triad.utils.pyarrow import (
-    TRIAD_DEFAULT_TIMESTAMP,
     get_alter_func,
     parse_json_columns,
     replace_types_in_table,
 )
 
-
 _PA_TYPE_TO_SF_TYPE: Dict[pa.DataType, str] = {
     pa.string(): "STRING",
     pa.bool_(): "BOOLEAN",
@@ -130,12 +128,12 @@ def fix_snowflake_arrow_result(result: pa.Table) -> pa.Table:
             pa.int64(),
         ),
         (lambda tp: pa.types.is_date64(tp), pa.date32()),
-        (
-            lambda tp: pa.types.is_timestamp(tp)
-            and tp.tz is None
-            and tp != TRIAD_DEFAULT_TIMESTAMP,
-            TRIAD_DEFAULT_TIMESTAMP,
-        ),
+        # (
+        #     lambda tp: pa.types.is_timestamp(tp)
+        #     and tp.tz is None
+        #     and tp != TRIAD_DEFAULT_TIMESTAMP,
+        #     TRIAD_DEFAULT_TIMESTAMP,
+        # ),
     ],
 )
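
Note: the rules above normalize Snowflake's Arrow output to types triad can
consume (the tz-naive timestamp rule is now disabled). For illustration, the
date64 rule is equivalent to this plain pyarrow cast (a standalone sketch, not
code from this commit):

    import datetime

    import pyarrow as pa

    t = pa.table({"d": pa.array([datetime.date(2024, 2, 14)], type=pa.date64())})
    fixed = t.cast(pa.schema([pa.field("d", pa.date32())]))
    print(fixed.schema)  # d: date32[day]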

@@ -159,28 +157,40 @@ def to_snowflake_schema(schema: Any) -> str:
 
 def get_arrow_from_batches(
     batches: Optional[List[ResultBatch]],
-    schema: None = None,
+    query_output_schema: Schema,
+    schema: Any = None,
     infer_nested_types: bool = False,
 ) -> pa.Table:
     if batches is None or len(batches) == 0:
-        if schema is not None:
-            return (
-                schema if isinstance(schema, Schema) else Schema(schema)
-            ).create_empty_arrow_table()
-        raise ValueError("No result")
+        return query_output_schema.create_empty_arrow_table()
+
+    def _batches_to_arrow(_batches: List[ResultBatch]) -> Iterable[pa.Table]:
+        has_result = False
+        for batch in _batches:
+            adf = batch.to_arrow()
+            if adf.num_rows == 0:
+                continue
+            func = get_alter_func(adf.schema, query_output_schema.pa_schema, safe=True)
+            has_result = True
+            yield func(adf)
+
+        if not has_result:
+            yield query_output_schema.create_empty_arrow_table()
+
+    adf = pa.concat_tables(_batches_to_arrow(batches))
+
     nested_cols = _get_nested_columns(batches[0])
-    adf = pa.concat_tables([x.to_arrow() for x in batches])
-    if adf.num_rows == 0:
-        return fix_snowflake_arrow_result(adf)
-    if schema is None:
-        adf = fix_snowflake_arrow_result(adf)
-        if infer_nested_types and len(nested_cols) > 0:
-            adf = parse_json_columns(adf, nested_cols)
-        return adf
-    _schema = schema if isinstance(schema, Schema) else Schema(schema)
-    adf = parse_json_columns(adf, nested_cols)
-    func = get_alter_func(adf.schema, _schema.pa_schema, safe=True)
-    return func(adf)
+    if infer_nested_types and len(nested_cols) > 0:
+        adf = parse_json_columns(adf, nested_cols)
+    if schema is not None:
+        _schema = schema if isinstance(schema, Schema) else Schema(schema)
+        func = get_alter_func(adf.schema, _schema.pa_schema, safe=True)
+        adf = func(adf)
+    return adf
 
 
 def _get_nested_columns(batch: ResultBatch) -> List[str]:
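
Note: the net effect of the new query_output_schema parameter is that callers
always get a table typed by the query's output schema, even when Snowflake
returns zero batches (the old code raised ValueError("No result")). A minimal
sketch of that guarantee using triad directly (schema expression illustrative):

    from triad import Schema

    out_schema = Schema("a:long,b:str")
    empty = out_schema.create_empty_arrow_table()
    assert empty.num_rows == 0
    assert empty.schema == out_schema.pa_schema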
@@ -192,17 +202,29 @@ def _get_nested_columns(batch: ResultBatch) -> List[str]:
     return res
 
 
+def _get_batch_arrow_schema(batch: ResultBatch) -> pa.Schema:
+    fields = [
+        pa.field(s.name, FIELD_TYPES[s.type_code].pa_type()) for s in batch.schema
+    ]
+    return pa.schema(fields)
+
+
 def temp_rand_str() -> str:
     return ("temp_" + str(uuid4()).split("-")[0]).upper()
 
 
 def build_package_list(packages: Iterable[str]) -> List[str]:
     ps: Set[str] = set()
     for p in packages:
-        if "=" in p or "<" in p or ">" in p:
-            ps.add(p)
-            continue
         try:
+            if "=" in p or "<" in p or ">" in p:
+                ps.add(p)
+            else:
+                ps.add(p + "==" + get_version(p))
         except Exception:  # pragma: no cover
             ps.add(p)
-        else:
-            ps.add(p + "==" + get_version(p))
     return list(ps)
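
Note: as restructured, an explicit version spec ("=", "<", ">") passes through
unchanged, a bare package name is pinned to the locally installed version, and
a name whose version lookup raises is kept unpinned. A usage sketch (resolved
versions illustrative):

    from fugue_snowflake._utils import build_package_list

    print(build_package_list(["pandas", "fugue==0.8.7", "pyarrow>=10"]))
    # e.g. ['pyarrow>=10', 'fugue==0.8.7', 'pandas==2.2.0'] -- order varies, a set is used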


fugue_snowflake/api.py (0 additions, 2 deletions)

@@ -77,10 +77,8 @@ def load(
     if not is_select_query(qot):
         query = f"SELECT * FROM {qot}"
     with fa.engine_context(engine, engine_conf=engine_conf, infer_by=["force_sf"]) as e:
-        print(e)
         client = SnowflakeClient.get_or_create(fa.get_current_conf())
         if isinstance(e, SnowflakeExecutionEngine):
-            print(qot)
             res: Any = client.query_or_table_to_ibis(qot)
             return SnowflakeDataFrame(res) if as_fugue else res
         elif not e.is_distributed and e.get_current_parallelism() <= 1:
fugue_snowflake/client.py (28 additions, 20 deletions)

@@ -72,7 +72,7 @@ def __init__(

         self._parallelism = 0
 
-        self._ibis = ibis.snowflake.connect(
+        self._ibis = ibis.fugue_snowflake.connect(
             account=account,
             user=user,
             password=password,
@@ -179,10 +179,9 @@ def get_current_parallelism(self) -> int:
             elif size == "x-large":
                 return 16
             elif "x-large" in size:
-                return 16 * int(size.split("x-large")[0])
+                return 16 ** (int(size.split("x-large")[0]) + 3)
             raise NotImplementedError(f"Unknown warehouse size: {size}")
         except Exception:
-            raise
             self._parallelism = 1
         return self._parallelism
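
Note: Snowflake warehouse compute roughly doubles with each size step (x-small
1, small 2, medium 4, large 8, x-large 16, 2x-large 32, ...). A sketch of a
size-to-parallelism mapping under that doubling assumption (illustration only;
the committed expression above computes 16 ** (n + 3), which grows much faster):

    def warehouse_parallelism(size: str) -> int:
        # doubling assumption: x-small -> 1 ... x-large -> 16, Nx-large -> 2 ** (N + 3)
        names = ["x-small", "small", "medium", "large", "x-large"]
        size = size.lower()
        if size in names:
            return 2 ** names.index(size)
        if "x-large" in size:
            n = int(size.split("x-large")[0])  # "2x-large" -> 2
            return 2 ** (n + 3)
        raise NotImplementedError(f"Unknown warehouse size: {size}")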

@@ -287,11 +286,8 @@ def query_to_arrow(
         infer_nested_types: bool = False,
         cursor: Optional[SnowflakeCursor] = None,
     ) -> pa.Table:
-        return get_arrow_from_batches(
-            self.query_to_result_batches(query, cursor=cursor),
-            schema=schema,
-            infer_nested_types=infer_nested_types,
-        )
+        it = self.query_or_table_to_ibis(query)
+        return self.ibis_to_arrow(it, schema, infer_nested_types, cursor)

     def ibis_to_arrow(
         self,
@@ -301,8 +297,12 @@ def ibis_to_arrow(
         cursor: Optional[SnowflakeCursor] = None,
     ) -> pa.Table:
         query = self.ibis_to_query_or_table(table, force_query=True)
-        return self.query_to_arrow(
-            query, schema=schema, infer_nested_types=infer_nested_types, cursor=cursor
+        table_schema = to_schema(table.schema())
+        return get_arrow_from_batches(
+            self.query_to_result_batches(query, cursor=cursor),
+            query_output_schema=table_schema,
+            schema=schema,
+            infer_nested_types=infer_nested_types,
         )

     def query_to_engine_df(
@@ -313,13 +313,15 @@ def query_to_engine_df(
         infer_nested_types: bool = False,
         cursor: Optional[SnowflakeCursor] = None,
     ) -> DataFrame:
+        it = self.query_or_table_to_ibis(query)
+        table_schema = to_schema(it.schema())
+
         if _is_snowflake_engine(engine):
-            return engine.to_df(self.query_or_table_to_ibis(query), schema=schema)
+            return engine.to_df(it, schema=schema)
         if schema is not None:
             _schema = schema if isinstance(schema, Schema) else Schema(schema)
         else:
-            tb = self.query_or_table_to_ibis(query)
-            _schema = to_schema(tb.schema())
+            _schema = table_schema
 
         batches = self.query_to_result_batches(query, cursor=cursor)
 
@@ -331,7 +333,10 @@ def _map(cursor: Any, df: LocalDataFrame) -> LocalDataFrame:
         def _map(cursor: Any, df: LocalDataFrame) -> LocalDataFrame:
             _b = [batches[row["id"]] for row in df.as_dict_iterable()]  # type: ignore
             adf = get_arrow_from_batches(
-                _b, schema=schema, infer_nested_types=infer_nested_types
+                _b,
+                query_output_schema=table_schema,
+                schema=_schema,
+                infer_nested_types=infer_nested_types,
             )
             return ArrowDataFrame(adf)
@@ -416,17 +421,17 @@ def _run(pdf: pd.DataFrame) -> Iterable[pd.DataFrame]:  # pragma: no cover
         _output_schema = to_snowflake_schema(output_schema)
         pv = sys.version_info
         python_version = f"{pv.major}.{pv.minor}"
-        packages = ["pandas", "cloudpickle", "fugue==0.8.7"]
+        packages = ["pandas", "cloudpickle", "fugue==0.8.7", "sqlalchemy"]
         if self.additional_packages != "":
             packages.extend(
-                x.strip().replace(" ", "") for x in self.additional_packages.split(",")
+                x.strip().replace(" ", "") for x in self.additional_packages.split(" ")
             )
         package_list = str(tuple(build_package_list(packages)))
         udtf_create = f"""
         CREATE OR REPLACE TEMP FUNCTION {udtf_name}({_input_schema})
         RETURNS TABLE ({_output_schema})
         LANGUAGE PYTHON
-        RUNTIME_VERSION={python_version}
+        RUNTIME_VERSION= '{python_version}'
         PACKAGES={package_list}
         --IMPORTS=('@fugue_staging/fugue-warehouses.zip')
         HANDLER='FugueTransformer'

def __enter__(self) -> "_Uploader":
create_stage_sql = (
f"CREATE STAGE IF NOT EXISTS {self._stage}" " FILE_FORMAT=(TYPE=PARQUET)"
f"CREATE STAGE IF NOT EXISTS {self._stage}"
" FILE_FORMAT=(TYPE=PARQUET USE_LOGICAL_TYPE=TRUE BINARY_AS_TEXT=FALSE)"
)
print(create_stage_sql)
self._cursor.execute(create_stage_sql).fetchall()
Expand Down Expand Up @@ -519,7 +525,9 @@ def _map(cursor: Any, df: LocalDataFrame) -> LocalDataFrame:
file = temp_rand_str() + ".parquet"
with TemporaryDirectory() as f:
path = os.path.join(f, file)
write_parquet(df.as_arrow(), path)
write_parquet(
df.as_arrow(), path, use_deprecated_int96_timestamps=False
)
with client.cursor() as cur:
cur.execute(f"PUT file://{path} @{stage_location}").fetchall()
return ArrayDataFrame([[file]], "file:str")
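
Note: this change pairs with the stage's new USE_LOGICAL_TYPE=TRUE option:
timestamps are written as Parquet logical types rather than deprecated int96
values, so Snowflake loads them with the intended semantics. A standalone
pyarrow sketch of the write side (path illustrative):

    import pyarrow as pa
    import pyarrow.parquet as pq

    table = pa.table({"ts": pa.array([0], type=pa.timestamp("us"))})
    pq.write_table(table, "/tmp/example.parquet", use_deprecated_int96_timestamps=False)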
Expand Down Expand Up @@ -550,7 +558,7 @@ def _copy_to_table(self, files: List[str], table: str) -> str:
f"COPY INTO {table} FROM"
f" @{self._stage}"
f" FILES = ({files_expr})"
f" FILE_FORMAT = (TYPE=PARQUET)"
f" FILE_FORMAT = (TYPE=PARQUET USE_LOGICAL_TYPE=TRUE BINARY_AS_TEXT=FALSE)"
f" MATCH_BY_COLUMN_NAME = CASE_SENSITIVE"
)
print(copy_sql)
fugue_snowflake/dataframe.py (1 addition, 1 deletion)

@@ -38,7 +38,7 @@ def _to_iterable_df(self, table: IbisTable, schema: Any = None) -> LocalDataFrame:
     def as_pandas(self) -> pd.DataFrame:
         return _sf_ibis_as_pandas(self.native)
 
-    def as_arrow(self) -> pa.Table:
+    def as_arrow(self, type_safe: bool) -> pa.Table:
         return _sf_ibis_as_arrow(self.native)


fugue_snowflake/execution_engine.py (3 additions, 1 deletion)

@@ -249,7 +249,9 @@ def map_dataframe(  # noqa: C901
             part = f"ABS(HASH({partitions_cols}) % {num})"
             query = self._udtf_template(part, dot, udtf, input_cols, order_by)
         else:  # no partition key
-            if algo == "default":
+            if (
+                algo == "default"
+            ):  # TODO: snowflake doesn't support as-is map_partitions
                 algo = "hash"
             if algo == "even":
                 if partition_spec.num_partitions.upper() == KEYWORD_ROWCOUNT:
fugue_snowflake/ibis_snowflake/backend.py (25 additions, 23 deletions)

@@ -1,28 +1,30 @@
-from typing import Iterator, Tuple
+from typing import Any
 
 import ibis.expr.datatypes as dt
-from ibis.backends.snowflake import Backend as BaseBackend, SnowflakeType
-from .._utils import temp_rand_str, normalize_name
+import sqlalchemy.types as sat
+from ibis.backends.snowflake import Backend as BaseBackend
+from ibis.backends.snowflake import (
+    SnowflakeCompiler,
+    SnowflakeExprTranslator,
+    SnowflakeType,
+)
 
 
+class FugueSnowflakeType(SnowflakeType):
+    @classmethod
+    def to_ibis(cls, typ: Any, nullable: bool = True):
+        if isinstance(typ, (sat.BINARY, sat.VARBINARY)):
+            return dt.Binary(nullable=nullable)
+        return super().to_ibis(typ, nullable=nullable)
+
+
+class FugueSnowflakeExprTranslator(SnowflakeExprTranslator):
+    type_mapper = FugueSnowflakeType
+
+
+class FugueSnowflakeCompiler(SnowflakeCompiler):
+    translator_class = FugueSnowflakeExprTranslator
+
+
 class Backend(BaseBackend):
-    def _metadata(self, query: str) -> Iterator[Tuple[str, dt.DataType]]:
-        with self.begin() as con:
-            database = normalize_name(self.current_database)
-            schema = normalize_name(self.current_schema)
-            table = normalize_name("IBIS_" + temp_rand_str())
-            full_name = f"{database}.{schema}.{table}"
-            create_sql = (
-                f"CREATE TEMP TABLE {full_name} AS SELECT * FROM ({query}) LIMIT 0"
-            )
-            print(create_sql)
-            con.exec_driver_sql(create_sql)
-            result = con.exec_driver_sql(f"DESC VIEW {full_name}").mappings().all()
-
-        for field in result:
-            name = field["name"]
-            type_string = field["type"]
-            is_nullable = field["null?"] == "Y"
-            yield name, SnowflakeType.from_string(type_string, nullable=is_nullable)
-
-        print("schema DONE")
+    compiler = FugueSnowflakeCompiler
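
Note: a minimal sketch of the override's effect, assuming the sqlalchemy-based
ibis 8.x API this file imports: Snowflake BINARY/VARBINARY columns map to an
ibis binary dtype instead of falling through to the default mapping.

    import sqlalchemy.types as sat

    from fugue_snowflake.ibis_snowflake.backend import FugueSnowflakeType

    print(FugueSnowflakeType.to_ibis(sat.BINARY()))     # binary
    print(FugueSnowflakeType.to_ibis(sat.VARBINARY()))  # binary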
fugue_snowflake/registry.py (2 additions, 1 deletion)

@@ -47,7 +47,8 @@ def _parse_sf(engine: str, conf: Any, **kwargs) -> ExecutionEngine:
 @infer_execution_engine.candidate(
     lambda objs: is_pandas_or(objs, SnowflakeDataFrame)
     or any(
-        is_sf_ibis_table(x) or (isinstance(x, str) and x == "force_sf") for x in objs
+        is_sf_ibis_table(x) or _is_sf(x) or (isinstance(x, str) and x == "force_sf")
+        for x in objs
     )
 )
 def _infer_sf_engine(objs: Any) -> Any:
setup.py (4 additions, 3 deletions)

@@ -37,16 +37,16 @@ def get_version() -> str:
             "pandas-gbq",
             "google-auth",
             "db-dtypes",
-            "ibis-framework[bigquery]",
+            "ibis-framework[bigquery]<9",
         ],
         "trino": [
             "fugue[sql,ibis]>=0.9",
-            "ibis-framework[trino]",
+            "ibis-framework[trino]<9",
         ],
         "ray": ["fugue[ray]>=0.9"],
         "snowflake": [
            "fugue[sql,ibis]==0.9.0.dev3",
-            "ibis-framework[snowflake]",
+            "ibis-framework[snowflake]<9",
             "snowflake-connector-python[pandas]",
             "snowflake-snowpark-python",
             "snowflake-cli-labs",
@@ -74,6 +74,7 @@ def get_version() -> str:
         ],
         "ibis.backends": [
             "fugue_trino = fugue_trino.ibis_trino",
+            "fugue_snowflake = fugue_snowflake.ibis_snowflake",
         ],
     },
 )
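
Note: registering "fugue_snowflake = fugue_snowflake.ibis_snowflake" under the
ibis.backends entry-point group is what makes the patched backend above
reachable as ibis.fugue_snowflake, which client.py now connects through. A
sketch (connection parameters illustrative):

    import ibis

    con = ibis.fugue_snowflake.connect(
        account="my_account",
        user="my_user",
        password="...",
        database="MY_DB/MY_SCHEMA",
    )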
