Skip to content

Commit

Permalink
Fix _merge_repr for numeric data types (#786)
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 authored Jan 22, 2024
1 parent 688f0ca commit 4d2fcbb
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,33 +245,32 @@ def _merge_repr(cls, restriction: str = True) -> dj.expression.Union:
for p in cls._merge_restrict_parts(
restriction=restriction,
add_invalid_restrict=False,
return_empties=True,
return_empties=False, # motivated by SpikeSortingOutput.Import
)
]

primary_attrs = list(
dict.fromkeys( # get all columns from parts
iter_chain.from_iterable([p.heading.names for p in parts])
attr_dict = { # NULL for non-numeric, 0 for numeric
attr.name: "0" if attr.numeric else "NULL"
for attr in iter_chain.from_iterable(
part.heading.attributes.values() for part in parts
)
)
# primary_attrs.append(cls()._reserved_sk)
query = dj.U(*primary_attrs) * parts[0].proj( # declare query
..., # include all attributes from part 0
**{
a: "NULL" # add null value where part has no column
for a in primary_attrs
if a not in parts[0].heading.names
},
)
for part in parts[1:]: # add to declared query for each part
query += dj.U(*primary_attrs) * part.proj(
...,
}

def _proj_part(part):
"""Project part, adding NULL/0 for missing attributes"""
return dj.U(*attr_dict.keys()) * part.proj(
..., # include all attributes from part
**{
a: "NULL"
for a in primary_attrs
if a not in part.heading.names
k: v
for k, v in attr_dict.items()
if k not in part.heading.names
},
)

query = _proj_part(parts[0]) # start with first part
for part in parts[1:]: # add remaining parts
query += _proj_part(part)

return query

@classmethod
Expand Down Expand Up @@ -529,6 +528,12 @@ def fetch_nwb(
multi_source: bool
Return from multiple parents. Default False.
"""
if isinstance(self, dict):
raise ValueError("Try replacing Merge.method with Merge().method")
if restriction == True and self.restriction:
if not disable_warning:
_warn_on_restriction(self, restriction)
restriction = self.restriction

return self.merge_restrict_class(restriction).fetch_nwb()

Expand Down

0 comments on commit 4d2fcbb

Please sign in to comment.