Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NWB Inspector Tests #1121

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
205 changes: 143 additions & 62 deletions neo/io/nwbio.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from pynwb.misc import AnnotationSeries
from pynwb import image
from pynwb.image import ImageSeries
from pynwb.file import Subject
from pynwb.epoch import TimeIntervals
from pynwb.spec import NWBAttributeSpec, NWBDatasetSpec, NWBGroupSpec, NWBNamespace, \
NWBNamespaceBuilder
from pynwb.device import Device
Expand All @@ -55,6 +57,13 @@
except ImportError:
have_pynwb = False

# nwbinspector is an optional import: its absence is recorded here and
# reported from NWBIO.__init__. It is used after writing a file to check
# compliance with NWB Best Practices.
try:
    # the from-import loads the package; a separate bare import is redundant
    from nwbinspector import inspect_nwb, check_regular_timestamps
    have_nwbinspector = True
except ImportError:
    have_nwbinspector = False

# hdmf imports
try:
from hdmf.spec import (LinkSpec, GroupSpec, DatasetSpec, SpecNamespace,
Expand Down Expand Up @@ -244,6 +253,8 @@ def __init__(self, filename, mode='r'):
raise Exception("Please install the pynwb package to use NWBIO")
if not have_hdmf:
raise Exception("Please install the hdmf package to use NWBIO")
if not have_nwbinspector:
raise Exception("Please install the nwbinspector package to use NWBIO")
BaseIO.__init__(self, filename=filename)
self.filename = filename
self.blocks_written = 0
Expand Down Expand Up @@ -275,6 +286,9 @@ def read_all_blocks(self, lazy=False, **kwargs):
if "file_create_date" in self.global_block_metadata:
self.global_block_metadata["file_datetime"] = self.global_block_metadata[
"rec_datetime"]
if "subject" in self.global_block_metadata:
self.global_block_metadata["subject"] = self.global_block_metadata[
"subject"]

self._blocks = {}
self._read_acquisition_group(lazy=lazy)
Expand Down Expand Up @@ -352,8 +366,6 @@ def _read_timeseries_group(self, group_name, lazy):
except JSONDecodeError:
# For NWB files created with other applications, we put everything in a single
# segment in a single block
# todo: investigate whether there is a reliable way to create multiple segments,
# e.g. using Trial information
block_name = "default"
segment_name = "default"
else:
Expand Down Expand Up @@ -441,8 +453,14 @@ def write_all_blocks(self, blocks, **kwargs):
raise Exception("Writing to NWB requires an annotation 'session_start_time'")
self.annotations = {"rec_datetime": "rec_datetime"}
self.annotations["rec_datetime"] = blocks[0].rec_datetime
# todo: handle subject
self.annotations = {"subject": "subject"}
JuliaSprenger marked this conversation as resolved.
Show resolved Hide resolved
nwbfile = NWBFile(**annotations)
if "subject" not in annotations:
nwbfile.subject = Subject(subject_id="subject_id",
age="P0D", # Period x days old
description="no description",
species="Mus musculus", # by default
sex="U")
JuliaSprenger marked this conversation as resolved.
Show resolved Hide resolved
assert self.nwb_file_mode in ('w',) # possibly expand to 'a'ppend later
if self.nwb_file_mode == "w" and os.path.exists(self.filename):
os.remove(self.filename)
Expand All @@ -458,15 +476,16 @@ def write_all_blocks(self, blocks, **kwargs):
'block', 'the name of the Neo Block to which the SpikeTrain belongs')

if sum(statistics(block)["Epoch"]["count"] for block in blocks) > 0:
nwbfile.add_epoch_column('_name', 'the name attribute of the Epoch')
# nwbfile.add_epoch_column('_description', 'the description attribute of the Epoch')
nwbfile.add_epoch_column(
'segment', 'the name of the Neo Segment to which the Epoch belongs')
nwbfile.add_epoch_column('block',
'the name of the Neo Block to which the Epoch belongs')
nwbfile.add_trial_column('segment', 'name of the Segment to which the Epoch belongs')
nwbfile.add_trial_column('block', 'name of the Block to which the Epoch belongs')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if nwb is supporting the epoch concept, what is the reason to switch to trials now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. You are right. As it supports the epoch concept, it is not necessary to have the trials function.
I put it back as before.


# Accumulate (t_start, t_stop) intervals of all Epochs, then write them
# ordered by t_start. Sort start/stop as *pairs* (argsort on t_start):
# sorting the two lists independently would corrupt overlapping epochs
# (one epoch's start would be paired with another epoch's stop).
arr = [[], []]
for i, block in enumerate(blocks):
    self.write_block(nwbfile, block, arr)
    # NOTE(review): `arr` accumulates across blocks, so with several
    # blocks the earlier intervals are written again for each later
    # block — confirm whether a per-block accumulator was intended.
    order = np.argsort(arr[0])
    sorted_intervals = [[arr[0][k] for k in order],
                        [arr[1][k] for k in order]]
    self._write_epoch(nwbfile, sorted_intervals, block)

io_nwb.write(nwbfile)
io_nwb.close()

Expand All @@ -475,7 +494,37 @@ def write_all_blocks(self, blocks, **kwargs):
if errors:
raise Exception(f"Errors found when validating {self.filename}")

def write_block(self, nwbfile, block, **kwargs):
# NWBInspector: inspect the written file for compliance with NWB Best
# Practices and print a short report for each finding. One loop with a
# headline lookup replaces three near-identical branches.
importance_headlines = {
    "CRITICAL": "Potentially incorrect data",
    "BEST_PRACTICE_VIOLATION": "Very suboptimal data representation",
    "BEST_PRACTICE_SUGGESTION": "Improvable data representation",
}
for message in inspect_nwb(nwbfile_path=self.filename):
    # use the public Enum attribute `.name`, not the private `._name_`
    headline = importance_headlines.get(message.importance.name)
    if headline is None:
        continue
    print("message.importance = ", message.importance)
    print(headline)
    print(message.message)
    print("message.check_function_name = ", message.check_function_name)
    print("message.object_type = ", message.object_type)
    print("message.object_name = ", message.object_name)
    print("----------------------")
# no further io_nwb.close() here: the file was already closed before
# validation, and inspect_nwb works on the file path, not the handle

def write_block(self, nwbfile, block, arr, **kwargs):
"""
Write a Block to the file
:param block: Block to be written
Expand All @@ -485,10 +534,11 @@ def write_block(self, nwbfile, block, **kwargs):
if not block.name:
block.name = "block%d" % self.blocks_written
for i, segment in enumerate(block.segments):
segment.name = "%s : segment%d" % (block.name, i)
assert segment.block is block
if not segment.name:
segment.name = "%s : segment%d" % (block.name, i)
self._write_segment(nwbfile, segment, electrodes)
self._write_segment(nwbfile, segment, electrodes, arr)
self.blocks_written += 1

def _write_electrodes(self, nwbfile, block):
Expand All @@ -512,8 +562,7 @@ def _write_electrodes(self, nwbfile, block):
)
return electrodes

def _write_segment(self, nwbfile, segment, electrodes):
# maybe use NWB trials to store Segment metadata?
def _write_segment(self, nwbfile, segment, electrodes, arr):
for i, signal in enumerate(
chain(segment.analogsignals, segment.irregularlysampledsignals)):
assert signal.segment is segment
Expand Down Expand Up @@ -541,8 +590,8 @@ def _write_segment(self, nwbfile, segment, electrodes):

for i, epoch in enumerate(segment.epochs):
if not epoch.name:
epoch.name = "%s : epoch%d" % (segment.name, i)
self._write_epoch(nwbfile, epoch)
epoch_name = "%s : epoch%d" % (segment.name, i)
self._write_manage_epoch(nwbfile, segment, epoch, arr)

def _write_signal(self, nwbfile, signal, electrodes):
hierarchy = {'block': signal.segment.block.name, 'segment': signal.segment.name}
Expand All @@ -563,23 +612,21 @@ def _write_signal(self, nwbfile, signal, electrodes):
units = signal.units
if isinstance(signal, AnalogSignal):
sampling_rate = signal.sampling_rate.rescale("Hz")
tS = timeseries_class(
name=signal.name,
starting_time=time_in_seconds(signal.t_start),
data=signal,
unit=units.dimensionality.string,
rate=float(sampling_rate),
comments=json.dumps(hierarchy),
**additional_metadata)
tS = timeseries_class(name=signal.name,
starting_time=time_in_seconds(signal.t_start),
data=signal,
unit=units.dimensionality.string,
rate=float(sampling_rate),
comments=json.dumps(hierarchy),
**additional_metadata)
# todo: try to add array_annotations via "control" attribute
elif isinstance(signal, IrregularlySampledSignal):
tS = timeseries_class(
name=signal.name,
data=signal,
unit=units.dimensionality.string,
timestamps=signal.times.rescale('second').magnitude,
comments=json.dumps(hierarchy),
**additional_metadata)
tS = timeseries_class(name=signal.name,
data=signal,
unit=units.dimensionality.string,
timestamps=signal.times.rescale('second').magnitude,
comments=json.dumps(hierarchy),
**additional_metadata)
else:
raise TypeError(
"signal has type {0}, should be AnalogSignal or IrregularlySampledSignal".format(
Expand Down Expand Up @@ -611,23 +658,54 @@ def _write_spiketrain(self, nwbfile, spiketrain):

def _write_event(self, nwbfile, event):
    """
    Write a Neo Event to the NWB file as an acquisition TimeSeries.

    The Neo block/segment hierarchy is stored in the "comments" field
    as a JSON string so it can be reconstructed on read.

    :param nwbfile: the pynwb NWBFile being written
    :param event: the Neo Event to write
    :return: the TimeSeries added to the file
    """
    hierarchy = {'block': event.segment.block.name, 'segment': event.segment.name}
    timestamps = event.times.rescale('second').magnitude
    intervals = np.diff(timestamps)
    # NWB best practice (flagged by nwbinspector): regularly-spaced
    # timestamps should be stored compactly as starting_time + rate.
    # The original check `any(timestamps) == any(timestamps)` was a
    # tautology, so every event got a fabricated starting_time/rate.
    if len(timestamps) >= 2 and intervals[0] > 0 and np.allclose(intervals, intervals[0]):
        tS_evt = TimeSeries(name=event.name,
                            data=event.labels,
                            starting_time=float(timestamps[0]),
                            rate=1.0 / float(intervals[0]),
                            unit=str(event.units),
                            description=event.description or "",
                            comments=json.dumps(hierarchy))
    else:
        # irregular (or too short to tell): keep explicit timestamps
        tS_evt = TimeSeries(name=event.name,
                            data=event.labels,
                            timestamps=timestamps,
                            unit=str(event.units),
                            description=event.description or "",
                            comments=json.dumps(hierarchy))

    nwbfile.add_acquisition(tS_evt)
    return tS_evt

def _write_epoch(self, nwbfile, epoch):
def _write_manage_epoch(self, nwbfile, segment, epoch, arr):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method could benefit from having a docstring. What exactly is arr here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arr is an epoch array with t_start and t_stop sorted in ascending order.
I don't think it's necessary to create a docstring for this function. It's just an intermediate step in order to respect NWBInspector.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you sort t_start and t_stop independently, doesn't this corrupt data in case one epoch duration is contained in another?

for t_start, duration, label in zip(epoch.rescale('s').magnitude,
epoch.durations.rescale('s').magnitude,
epoch.labels):
nwbfile.add_epoch(t_start, t_start + duration, [label], [],
_name=epoch.name,
segment=epoch.segment.name,
block=epoch.segment.block.name)
epoch.labels,
):
for j in [label]:
t_stop = t_start + duration
seg_name = "%s %s" % (epoch.segment.name, label)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here occurs the same error as with the events earlier. You have to keep a reference to the segment from before the proxy object is loaded.

bl_name = "%s %s" % (epoch.segment.block.name, label)
epoch_name = "%s %s" % (segment.name, j)

arr[0].append(t_start)
arr[1].append(t_stop)

def _write_epoch(self, nwbfile, arr2, block):
    """
    Write the collected epoch intervals of a Block as NWB trials.

    :param nwbfile: the pynwb NWBFile being written
    :param arr2: pair of sequences [t_starts, t_stops] in seconds,
        pre-sorted by the caller. TODO confirm the caller keeps starts
        and stops paired; sorting them independently corrupts
        overlapping epochs (see review discussion).
    :param block: the Neo Block the intervals belong to
    """
    for i in range(len(arr2[0])):
        t_start = arr2[0][i]
        t_stop = arr2[1][i]
        # NOTE(review): one trial is added per (interval, segment) pair,
        # so each interval is duplicated for every segment of the block —
        # confirm this duplication is intended.
        for k in block.segments:
            segment_name = k.name
            nwbfile.add_trial(start_time=t_start,
                              stop_time=t_stop,
                              tags=[" "],  # placeholder tag — confirm
                              timeseries=[],
                              segment=segment_name,
                              block=block.name,
                              )
    # NOTE(review): trials are written above, but `epochs` is returned —
    # confirm whether `nwbfile.trials` was meant.
    return nwbfile.epochs


Expand All @@ -644,6 +722,9 @@ def __init__(self, timeseries, nwb_group):
self.units = timeseries.unit
if timeseries.conversion:
self.units = _recompose_unit(timeseries.unit, timeseries.conversion)
check_timestamps = check_regular_timestamps(timeseries)
if check_timestamps is not None:
timeseries.starting_time = 0.0
if timeseries.starting_time is not None:
self.t_start = timeseries.starting_time * pq.s
else:
Expand Down Expand Up @@ -712,27 +793,27 @@ def load(self, time_slice=None, strict_slicing=True):
time_slice, strict_slicing=strict_slicing)
signal = self._timeseries.data[i_start: i_stop]
if self.sampling_rate is None:
return IrregularlySampledSignal(
self._timeseries.timestamps[i_start:i_stop] * pq.s,
signal,
units=self.units,
t_start=sig_t_start,
sampling_rate=self.sampling_rate,
name=self.name,
description=self.description,
array_annotations=None,
**self.annotations) # todo: timeseries.control / control_description
return IrregularlySampledSignal(self._timeseries.timestamps[i_start:i_stop] * pq.s,
signal=signal,
units=self.units,
t_start=sig_t_start,
sampling_rate=self.sampling_rate,
name=self.name,
description=self.description,
array_annotations=None,
**self.annotations)
# todo: timeseries.control / control_description

else:
return AnalogSignal(
signal,
units=self.units,
t_start=sig_t_start,
sampling_rate=self.sampling_rate,
name=self.name,
description=self.description,
array_annotations=None,
**self.annotations) # todo: timeseries.control / control_description
return AnalogSignal(signal,
units=self.units,
t_start=sig_t_start,
sampling_rate=self.sampling_rate,
name=self.name,
description=self.description,
array_annotations=None,
**self.annotations)
# todo: timeseries.control / control_description


class EventProxy(BaseEventProxy):
Expand Down
Loading