Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add placed_on_screen attribute to StimulusSet metadata #46

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions brainio/assemblies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
import xarray as xr
from xarray import DataArray, IndexVariable

from brainio.stimuli import StimulusSet

BRAINIO_CHUNKS = 'BRAINIO_CHUNKS'

_logger = logging.getLogger(__name__)


def is_fastpath(*args, **kwargs):
    """Return True if *args*/*kwargs* would put ``DataArray.__init__`` on its fastpath.

    xarray's ``DataArray.__init__`` accepts a ``fastpath`` flag either as a keyword
    argument or as the last positional argument (the 7th, as of xarray 0.16.1).

    :param args: positional arguments as they would be passed to ``DataArray.__init__``.
    :param kwargs: keyword arguments as they would be passed to ``DataArray.__init__``.
    :return: True if a truthy ``fastpath`` flag is present, False otherwise.
    """
    n = 7  # maximum length of args if all arguments to DataArray are positional (as of 0.16.1)
    return ("fastpath" in kwargs and kwargs["fastpath"]) or (len(args) >= n and args[n - 1])


class DataPoint(object):
Expand Down Expand Up @@ -86,7 +88,7 @@ def __init__(self, values):

def __eq__(self, other):
return len(self.values) == len(other.values) and \
all(v1 == v2 for v1, v2 in zip(self.values, other.values))
all(v1 == v2 for v1, v2 in zip(self.values, other.values))

    def __lt__(self, other):
        """Order instances by lexicographic comparison of their ``values`` sequences."""
        return self.values < other.values
Expand Down Expand Up @@ -337,28 +339,30 @@ def get_metadata_before_2022_06(assembly, dims=None, names_only=False, include_c
"""
Return coords and/or indexes or index levels from an assembly, yielding either `name` or `(name, dims, values)`.
"""

def what(name, dims, values, names_only):
if names_only:
return name
else:
return name, dims, values

if dims is None:
dims = assembly.dims + (None,) # all dims plus dimensionless coords
dims = assembly.dims + (None,) # all dims plus dimensionless coords
for name in assembly.coords.variables:
values = assembly.coords.variables[name]
is_subset = values.dims and (set(values.dims) <= set(dims))
is_dimless = (not values.dims) and None in dims
if is_subset or is_dimless:
is_index = isinstance(values, IndexVariable)
if is_index:
if values.level_names: # it's a MultiIndex
if values.level_names: # it's a MultiIndex
if include_multi_indexes:
yield what(name, values.dims, values.values, names_only)
if include_levels:
for level in values.level_names:
level_values = assembly.coords[level]
yield what(level, level_values.dims, level_values.values, names_only)
else: # it's an Index
else: # it's an Index
if include_indexes:
yield what(name, values.dims, values.values, names_only)
else:
Expand All @@ -367,17 +371,19 @@ def what(name, dims, values, names_only):


def get_metadata_after_2022_06(assembly, dims=None, names_only=False, include_coords=True,
include_indexes=True, include_multi_indexes=False, include_levels=True):
include_indexes=True, include_multi_indexes=False, include_levels=True):
"""
Return coords and/or indexes or index levels from an assembly, yielding either `name` or `(name, dims, values)`.
"""

def what(name, dims, values, names_only):
if names_only:
return name
else:
return name, dims, values

if dims is None:
dims = assembly.dims + (None,) # all dims plus dimensionless coords
dims = assembly.dims + (None,) # all dims plus dimensionless coords
for name, values in assembly.coords.items():
none_but_keep = (not values.dims) and None in dims
shared = not (set(values.dims).isdisjoint(set(dims)))
Expand Down Expand Up @@ -407,7 +413,7 @@ def get_metadata(assembly, dims=None, names_only=False, include_coords=True,
include_indexes, include_multi_indexes, include_levels)
except TypeError as e:
yield from get_metadata_before_2022_06(assembly, dims, names_only, include_coords,
include_indexes, include_multi_indexes, include_levels)
include_indexes, include_multi_indexes, include_levels)


def coords_for_dim(assembly, dim):
Expand All @@ -434,7 +440,8 @@ def gather_indexes(assembly):
"""This is only necessary as long as xarray cannot persist MultiIndex to netCDF. """
coords_d = {}
for dim in assembly.dims:
coord_names = list(get_metadata(assembly, dims=(dim,), names_only=True, include_indexes=False, include_levels=False))
coord_names = list(
get_metadata(assembly, dims=(dim,), names_only=True, include_indexes=False, include_levels=False))
if coord_names:
coords_d[dim] = coord_names
if coords_d:
Expand All @@ -457,7 +464,7 @@ def load(self):
try:
import dask
result = xr.open_dataarray(self.file_path, group=self.group, chunks=chunks)
except ModuleNotFoundError as e:
except ModuleNotFoundError:
result = xr.open_dataarray(self.file_path, group=self.group)
result = self.correct_stimulus_id_name(result)
result = self.assembly_class(data=result)
Expand Down Expand Up @@ -501,7 +508,7 @@ def load(self):
result = self.merge_stimulus_set_meta(result, self.stimulus_set)
return result

def merge_stimulus_set_meta(self, assy, stimulus_set):
def merge_stimulus_set_meta(self, assy: DataAssembly, stimulus_set: StimulusSet) -> DataAssembly:
dim_name, index_column = "presentation", "stimulus_id"
assy = assy.reset_index(list(assy.indexes))
df_of_coords = pd.DataFrame(coords_for_dim(assy, dim_name))
Expand Down Expand Up @@ -539,6 +546,3 @@ def load(self):
exc_info=True
)
return result



3 changes: 2 additions & 1 deletion brainio/stimuli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

class StimulusSet(pd.DataFrame):
# http://pandas.pydata.org/pandas-docs/stable/development/extending.html#subclassing-pandas-data-structures
_metadata = pd.DataFrame._metadata + ["identifier", "get_stimulus", 'get_loader_class', "stimulus_paths", "from_files"]
_metadata = pd.DataFrame._metadata + ["identifier", "get_stimulus", "get_loader_class",
"stimulus_paths", "from_files", "placed_on_screen"]

@property
def _constructor(self):
Expand Down