Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Projector implementation #585

Merged
merged 6 commits into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions databroker/projector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import xarray
from importlib import import_module

from .core import BlueskyRun


class ProjectionError(Exception):
pass


def get_run_projection(run: BlueskyRun, projection_name: str = None):
"""Finds a projection in the run.
If projection_name is provided, searches through the projections in the run
to find a match.

Otherwise, looks in the run to see if there is only one projection. If so, returns it.

Parameters
----------
run : BlueskyRun
Run to investigate for a projection
projection_name : str, optional
name of the projection to look for, by default None

Returns
-------
dict
returns a projection dictionary, or None of not found

Raises
------
KeyError
If the a projection_name is specified and there is more than one
projection in the run with that name
"""

if projection_name is not None:
projections = [projection for projection in run.metadata['start']['projections']
if projection.get('name') == projection_name]
if len(projections) > 1:
raise KeyError("Multiple projections of name {projection_name} found")
if len(projections) == 1:
return projections[0]
if len(projections) == 0:
return None

if 'projections' in run.metadata['start'] and len(run.metadata['start']['projections']) == 1:
return run.metadata['start']['projections'][0]

return None


def get_calculated_value(run: BlueskyRun, key: str, mapping: dict):
"""Calls and returns the callable from the calculated projection mapping.

It is ancticipated that the return will be
and xarray.DataArray.

This should be expressed in the familiar 'module:func' syntax borrowed from python entry-points.

An example implementation of a calculated field projection entry:

'/projection/key': {
"type": "calculated",
"callable": "foo.bar:really_fun",
"args": ['arg1'], "kwargs": {"foo": "bar"}}

And a corresponding function implementation might be:

def really_fun(run, *args, **kwargs)"
# args will be ['arg1']
# kwargs will be {"foo": "bar"}
# for this calculated field
return xarray.DataArray[[1, 2, 3]]


Parameters
----------
run : BlueskyRun
run which can be used for the calcuation
key : str
key name for this projection
mapping : dict
full contents of this projection

Returns
-------
any
result of calling the method specified in the calcated field in the projection

Raises
------
ProjectionError
[description]
"""
callable_name = mapping['callable']
try:
module_name, function_name = callable_name.split(":")
module = import_module(module_name)
callable_func = getattr(module, function_name)
except ProjectionError as e:
raise ProjectionError('Error importing callable {function_name}', e)

calc_args = mapping['args']
calc_kwargs = mapping['kwargs']
return callable_func(run, *calc_args, **calc_kwargs)


def project_xarray(run: BlueskyRun, *args, projection=None, projection_name=None, **kwargs):
"""Produces an xarray Dataset by projecting the provided run. Selects projection based on
logic of get_run_projection().


Projections come with multiple types: linked, and caclulated. Calculated fields are only supported
in the data (not at the top-level attrs).

Calculated fields in projections schema contain a callable field. This should be expressed in
the familiar 'module:func' syntax borrowed from python entry-points.

All projections with "location"="configuration" will look in the start document
for metadata. Each field will be added to the return Dataset's attrs dictionary keyed
on projection key.


All projections with "location"="event" will look for a field in a stream.

Parameters
----------
run : BlueskyRun
run to project
projection_name : str, optional
name of a projection to select in the run, by default None
projection : dict, optional
projection not from the run to use, by default None

Returns
-------
xarray.Dataset
The return Dataset will contain:
- single value data (typically from the run start) in the return Dataset's attrs dict, keyed
on the projection key. These are projections marked "location": "configuration"

- multi-value data (typically from a stream). Keys for the dict-like xarray.Dataset match keys
in the passed-in projection. These are projections with "location": "linked"

Raises
------
ProjectionError
"""
try:
if projection is None:
projection = get_run_projection(run, projection_name)
if projection is None:
raise ProjectionError("Projection could not be found")

attrs = {} # will populate the return Dataset attrs field
data_vars = {} # will populate the return Dataset DataArrays
for field_key, mapping in projection['projection'].items():
# go through each projection
projection_type = mapping['type']
projection_location = mapping.get('location')
projection_data = None
projection_linked_field = mapping.get('field')

# single value data that will go in the top
# dataset's attributes
if projection_location == 'configuration':
attrs[field_key] = run.metadata['start'][projection_linked_field]
continue

# added to return Dataset in data_vars dict
if projection_type == "calculated":
data_vars[field_key] = get_calculated_value(run, field_key, mapping)
continue

# added to return Dataset in data_vars dict
if projection_location == 'event':
projection_stream = mapping.get('stream')
if projection_stream is None:
raise ProjectionError(f'stream missing for event projection: {field_key}')
data_vars[field_key] = run[projection_stream].to_dask()[projection_linked_field]

elif projection_location == 'configuration':
attrs[field_key] = projection_data
else:
raise KeyError(f'Unknown location: {projection_location} in projection.')

except Exception as e:
raise ProjectionError('Error projecting run') from e
return xarray.Dataset(data_vars, attrs=attrs)
169 changes: 169 additions & 0 deletions databroker/tests/test_projector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import pytest
import xarray
from databroker.core import BlueskyRun

from ..projector import get_run_projection, project_xarray, ProjectionError

NEX_IMAGE_FIELD = '/entry/instrument/detector/data'
NEX_ENERGY_FIELD = '/entry/instrument/monochromator/energy'
NEX_SAMPLE_NAME_FIELD = '/entry/sample/name'
MOCK_IMAGE = xarray.DataArray([[1, 2], [3, 4]])
BEAMLINE_ENERGY_VALS = [1, 2, 3, 4, 5]
OTHER_VALS = [-1, -2, -3, -4, -5]
CCD = [MOCK_IMAGE+1, MOCK_IMAGE+2, MOCK_IMAGE+3, MOCK_IMAGE+4, MOCK_IMAGE+5]
good_projection = [{
"name": "nxsas",
"version": "2020.1",
"configuration": {"name": "RSoXS"},
"projection": {
NEX_SAMPLE_NAME_FIELD: {"type": "linked", "location": "configuration", "field": "sample"},
NEX_IMAGE_FIELD: {"type": "linked", "location": "event", "stream": "primary", "field": "ccd"},
NEX_ENERGY_FIELD: {"type": "linked", "location": "event", "stream": "primary",
"field": "beamline_energy"},
}
}]

bad_location = [{
"name": "nxsas",
"version": "2020.1",
"configuration": {"name": "RSoXS"},
"projection": {
NEX_SAMPLE_NAME_FIELD: {"type": "linked", "location": "i_dont_exist", "field": "sample"},

}
}]

bad_stream = [{
"name": "nxsas",
"version": "2020.1",
"configuration": {"name": "RSoXS"},
"projection": {
NEX_SAMPLE_NAME_FIELD: {"type": "linked", "location": "configuration", "field": "sample"},
NEX_IMAGE_FIELD: {"type": "linked", "location": "event", "stream": "i_dont_exist", "field": "ccd"},

}
}]

bad_field = [{
"name": "nxsas",
"version": "2020.1",
"configuration": {"name": "RSoXS"},
"projection": {
NEX_SAMPLE_NAME_FIELD: {"type": "linked", "location": "configuration", "field": "sample"},
NEX_IMAGE_FIELD: {"type": "linked", "location": "event", "stream": "primary", "field": "i_dont_exist"},

}
}]

projections_same_name = [
{
"name": "nxsas"
},
{
"name": "nxsas"
}
]


class MockStream():
def __init__(self, metadata):
self.metadata = metadata
data_vars = {
'beamline_energy': ('time', BEAMLINE_ENERGY_VALS),
'ccd': (('time', 'dim_0', 'dim_1'), CCD)
}
self.dataset = xarray.Dataset(data_vars)
self.to_dask_counter = 0

def to_dask(self):
# This enables us to test that the to_dask function is called
# the appropriate number of times.
# It would be better if we could actually return the dataset as a dask dataframe
# However, for some reason this won't let us access the arrays
# by numeric index and will throw an error
self.to_dask_counter += 1
return self.dataset


class MockRun():
def __init__(self, projections=[], sample='',):
self.metadata = {
'start': {
'sample': sample,
'projections': projections
},
'stop': {}
}

self.primary = MockStream(self.metadata)

def __getitem__(self, key):
if key == 'primary':
return self.primary
raise KeyError(f'Key: {key}, does not exist')


def make_mock_run(projections, sample):
return MockRun(projections, sample)


def dont_panic(run: BlueskyRun, *args, **kwargs):
# TODO test that args and kwargs are passed
return xarray.DataArray([42, 42, 42, 42, 42])


def test_calculated_projections():
calculated_projection = [{
"name": "nxsas",
"version": "2020.1",
"configuration": {"name": "RSoXS"},
"projection": {
'/entry/event/computed': {
"type": "calculated",
"callable": "databroker.tests.test_projector:dont_panic",
"args": ['trillian'], "kwargs": {"ford": "prefect"}}
}
}]

mock_run = make_mock_run(calculated_projection, 'garggle_blaster')
dataset = project_xarray(mock_run)
comparison = dataset['/entry/event/computed'] == [42, 42, 42, 42, 42]
assert comparison.all()


def test_find_projection_in_run():
mock_run = make_mock_run(good_projection, 'one_ring')
assert get_run_projection(mock_run, projection_name="nxsas") == good_projection[0]
assert get_run_projection(mock_run, projection_name="vogons") is None
assert get_run_projection(mock_run) == good_projection[0] # only one projection in run so choose it
with pytest.raises(KeyError):
mock_run = make_mock_run(projections_same_name, 'one_ring')
get_run_projection(mock_run, projection_name="nxsas")


def test_unknown_location():
mock_run = make_mock_run(bad_location, 'one_ring')
with pytest.raises(ProjectionError):
projector = project_xarray(mock_run)
projector.project(bad_location[0])


def test_nonexistent_stream():
mock_run = make_mock_run(bad_stream, 'one_ring')
with pytest.raises(ProjectionError):
projector = project_xarray(mock_run)
projector.project(bad_stream[0])


def test_projector():
mock_run = make_mock_run(good_projection, 'one_ring')
dataset = project_xarray(mock_run)
# Ensure that the to_dask function was called on both
# energy and image datasets
assert mock_run['primary'].to_dask_counter == 2
assert dataset.attrs[NEX_SAMPLE_NAME_FIELD] == mock_run.metadata['start']['sample']
for idx, energy in enumerate(dataset[NEX_ENERGY_FIELD]):
assert energy == mock_run['primary'].dataset['beamline_energy'][idx]
for idx, image in enumerate(dataset[NEX_IMAGE_FIELD]):
comparison = image == mock_run['primary'].dataset['ccd'][idx] # xarray of comparison results
assert comparison.all() # False if comparision does not contain all True