From 4d2fcbbd71f3dfa0b6e110a2c196786d66ea7a36 Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Mon, 22 Jan 2024 16:09:32 -0600 Subject: [PATCH] Fix _merge_repr for numeric data types (#786) --- src/spyglass/utils/dj_merge_tables.py | 45 +++++++++++++++------------ 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 78c42bdc7..b835e37bc 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -245,33 +245,32 @@ def _merge_repr(cls, restriction: str = True) -> dj.expression.Union: for p in cls._merge_restrict_parts( restriction=restriction, add_invalid_restrict=False, - return_empties=True, + return_empties=False, # motivated by SpikeSortingOutput.Import ) ] - primary_attrs = list( - dict.fromkeys( # get all columns from parts - iter_chain.from_iterable([p.heading.names for p in parts]) + attr_dict = { # NULL for non-numeric, 0 for numeric + attr.name: "0" if attr.numeric else "NULL" + for attr in iter_chain.from_iterable( + part.heading.attributes.values() for part in parts ) - ) - # primary_attrs.append(cls()._reserved_sk) - query = dj.U(*primary_attrs) * parts[0].proj( # declare query - ..., # include all attributes from part 0 - **{ - a: "NULL" # add null value where part has no column - for a in primary_attrs - if a not in parts[0].heading.names - }, - ) - for part in parts[1:]: # add to declared query for each part - query += dj.U(*primary_attrs) * part.proj( - ..., + } + + def _proj_part(part): + """Project part, adding NULL/0 for missing attributes""" + return dj.U(*attr_dict.keys()) * part.proj( + ..., # include all attributes from part **{ - a: "NULL" - for a in primary_attrs - if a not in part.heading.names + k: v + for k, v in attr_dict.items() + if k not in part.heading.names }, ) + + query = _proj_part(parts[0]) # start with first part + for part in parts[1:]: # add remaining parts + query += _proj_part(part) + return query @classmethod @@ -529,6 +528,12 @@ def fetch_nwb( multi_source: bool Return from multiple parents. Default False. """ + if isinstance(self, dict): + raise ValueError("Try replacing Merge.method with Merge().method") + if restriction == True and self.restriction: + if not disable_warning: + _warn_on_restriction(self, restriction) + restriction = self.restriction return self.merge_restrict_class(restriction).fetch_nwb()