diff --git a/src/gateway/converter/spark_functions.py b/src/gateway/converter/spark_functions.py
index b0701e2..56bf436 100644
--- a/src/gateway/converter/spark_functions.py
+++ b/src/gateway/converter/spark_functions.py
@@ -2,6 +2,7 @@
 """Provides the mapping of Spark functions to Substrait."""
 import dataclasses
 
+from backends.backend_options import BackendEngine
 from gateway.converter.conversion_options import ConversionOptions
 from substrait.gen.proto import algebra_pb2, type_pb2
 
@@ -460,11 +461,31 @@ def __lt__(self, obj) -> bool:
             i64=type_pb2.Type.I64(
                 nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
 }
+
+# DuckDB-specific overrides. The generic table is unpacked first so that the
+# entries listed after it take precedence when a key appears in both.
+SPARK_SUBSTRAIT_MAPPING_FOR_DUCKDB = {
+    **SPARK_SUBSTRAIT_MAPPING,
+    'struct': ExtensionFunction(
+        '/functions_structs.yaml', 'struct_pack:any_str', type_pb2.Type(
+            i64=type_pb2.Type.I64(
+                nullability=type_pb2.Type.Nullability.NULLABILITY_REQUIRED))),
+}
+
+
+def find_mapping(options: ConversionOptions) -> dict[str, ExtensionFunction]:
+    """Return the Spark to Substrait function table for the configured backend."""
+    match options.backend.backend:
+        case BackendEngine.DUCKDB:
+            return SPARK_SUBSTRAIT_MAPPING_FOR_DUCKDB
+        case _:
+            return SPARK_SUBSTRAIT_MAPPING
 
 
 def lookup_spark_function(name: str, options: ConversionOptions) -> ExtensionFunction:
     """Return a Substrait function given a spark function name."""
-    definition = SPARK_SUBSTRAIT_MAPPING.get(name)
+    mapping = find_mapping(options)
+    definition = mapping.get(name)
     if definition is None:
         raise ValueError(f'Function {name} not found in the Spark to Substrait mapping table.')
     if not options.return_names_with_types: