Skip to content

Commit

Permalink
Merge pull request #233 from imagej/200-converting-between-xarray-and…
Browse files Browse the repository at this point in the history
…-dataset-should-convert-attrs-types-better

Preserve image metadata in xarray <-> Dataset conversions
  • Loading branch information
ctrueden authored Nov 11, 2022
2 parents ce9d3a5 + 2437cb1 commit 570ff57
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 36 deletions.
14 changes: 14 additions & 0 deletions src/imagej/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,20 @@ def _add_converters(self):
priority=sj.Priority.HIGH - 2,
)
)
sj.add_py_converter(
sj.Converter(
predicate=lambda obj: isinstance(obj, jc.ImageMetadata),
converter=lambda obj: convert.image_metadata_to_dict(self._ij, obj),
priority=sj.Priority.HIGH - 2,
)
)
sj.add_py_converter(
sj.Converter(
predicate=lambda obj: isinstance(obj, jc.MetadataWrapper),
converter=lambda obj: convert.metadata_wrapper_to_dict(self._ij, obj),
priority=sj.Priority.HIGH - 2,
)
)

def _format_argument(self, key, value, ij1_style):
if value is True:
Expand Down
16 changes: 16 additions & 0 deletions src/imagej/_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def Throwable(self):
def ImagePlus(self):
return "ij.ImagePlus"

@blocking_import
def ImageMetadata(self):
return "io.scif.ImageMetadata"

@blocking_import
def MetadataWrapper(self):
return "io.scif.filters.MetadataWrapper"

@blocking_import
def LabelingIOService(self):
return "io.scif.labeling.LabelingIOService"
Expand Down Expand Up @@ -98,6 +106,10 @@ def AxisType(self):
def CalibratedAxis(self):
return "net.imagej.axis.CalibratedAxis"

@blocking_import
def ClassUtils(self):
return "org.scijava.util.ClassUtils"

@blocking_import
def Dimensions(self):
return "net.imglib2.Dimensions"
Expand All @@ -122,6 +134,10 @@ def ImgView(self):
def ImgLabeling(self):
return "net.imglib2.roi.labeling.ImgLabeling"

@blocking_import
def Named(self):
return "org.scijava.Named"

@blocking_import
def Util(self):
return "net.imglib2.util.Util"
Expand Down
50 changes: 49 additions & 1 deletion src/imagej/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def xarray_to_dataset(ij: "jc.ImageJ", xarr) -> "jc.Dataset":
dataset = ndarray_to_dataset(ij, xarr.values)
axes = dims._assign_axes(xarr)
dataset.setAxes(axes)
dataset.setName(xarr.name)
_assign_dataset_metadata(dataset, xarr.attrs)

return dataset
Expand Down Expand Up @@ -233,12 +234,14 @@ def java_to_xarray(ij: "jc.ImageJ", jobj) -> xr.DataArray:
xr_axes = list(permuted_rai.dim_axes)
xr_dims = list(permuted_rai.dims)
xr_attrs = sj.to_python(permuted_rai.getProperties())
xr_attrs = {sj.to_python(k): sj.to_python(v) for k, v in xr_attrs.items()}
# reverse axes and dims to match narr
xr_axes.reverse()
xr_dims.reverse()
xr_dims = dims._convert_dims(xr_dims, direction="python")
xr_coords = dims._get_axes_coords(xr_axes, xr_dims, narr.shape)
return xr.DataArray(narr, dims=xr_dims, coords=xr_coords, attrs=xr_attrs)
name = jobj.getName() if isinstance(jobj, jc.Named) else None
return xr.DataArray(narr, dims=xr_dims, coords=xr_coords, attrs=xr_attrs, name=name)


def supports_java_to_ndarray(ij: "jc.ImageJ", obj) -> bool:
Expand Down Expand Up @@ -459,6 +462,51 @@ def supports_imglabeling_to_labeling(obj):
return isinstance(obj, jc.ImgLabeling)


#######################
# Metadata converters #
#######################


def image_metadata_to_dict(ij: "jc.ImageJ", image_meta: "jc.ImageMetadata"):
"""
Converts an io.scif.ImageMetadata to a Python dict.
The components should be enough to create a new ImageMetadata.
:param ij: The ImageJ2 gateway (see imagej.init)
:param image_meta: The ImageMetadata to convert
:return: A Python dict representing image_meta
"""

# We import io.scif.Field here.
# This will prevent any conflicts with java.lang.reflect.Field.
Field = sj.jimport("io.scif.Field")

# Convert to a dict - preserve information by copying all SCIFIO fields.
#
# If info is left out of this dict, make sure that
# information is annotated with @Field upstream!
return {
str(f.getName()): ij.py.from_java(jc.ClassUtils.getValue(f, image_meta))
for f in jc.ClassUtils.getAnnotatedFields(image_meta.getClass(), Field)
}


def metadata_wrapper_to_dict(ij: "jc.ImageJ", metadata_wrapper: "jc.MetadataWrapper"):
"""
Converts a io.scif.filters.MetadataWrapper to a Python Dict.
The components should be enough to create a new MetadataWrapper
:param ij: The ImageJ2 gateway (see imagej.init)
:param metadata_wrapper: The MetadataWrapper to convert
:return: A Python dict representing metadata_wrapper
"""

return dict(
impl_cls=type(metadata_wrapper),
metadata=metadata_wrapper.unwrap(),
)


####################
# Helper functions #
####################
Expand Down
29 changes: 6 additions & 23 deletions src/imagej/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,10 @@ def rai_slice(rai, imin: Tuple, imax: Tuple, istep: Tuple):
:return: Sliced ImgLib2 RandomAccessibleInterval.
"""

# HACK: Avoid importing JLong at global scope.
# Otherwise, building the sphinx docs in doc/rtd fails with:
#
# Warning, treated as error:
# autodoc: failed to determine imagej.stack.JLong (<java class 'JLong'>) to be
# documented, the following exception was raised:
# Java Virtual Machine is not running
#
# Which can be reproduced in a REPL like this:
#
# >>> from jpype import JLong
# >>> help(JLong)
#
# So while the import here is unfortunate, it avoids the issue.
# TODO: Change to scyjava.new_jarray once we have that function.
from jpype import JArray, JLong

Views = sj.jimport("net.imglib2.view.Views")
shape = rai.shape
imin_fix = JArray(JLong)(len(shape))
imax_fix = JArray(JLong)(len(shape))
imin_fix = sj.jarray("j", [len(shape)])
imax_fix = sj.jarray("j", [len(shape)])
dim_itr = range(len(shape))

for py_dim, j_dim in zip(dim_itr, dim_itr):
Expand All @@ -50,23 +33,23 @@ def rai_slice(rai, imin: Tuple, imax: Tuple, istep: Tuple):
index = imin[py_dim]
if index < 0:
index += shape[j_dim]
imin_fix[j_dim] = JLong(index)
imin_fix[j_dim] = index
# Set maximum
if imax[py_dim] is None:
index = shape[j_dim] - 1
else:
index = imax[py_dim]
if index < 0:
index += shape[j_dim]
imax_fix[j_dim] = JLong(index)
imax_fix[j_dim] = index

istep_fix = JArray(JLong)(istep)
istep_fix = sj.jarray("j", [istep])

if _index_within_range(imin_fix, shape) and _index_within_range(imax_fix, shape):
intervaled = Views.interval(rai, imin_fix, imax_fix)
stepped = Views.subsample(intervaled, istep_fix)

# TODO: better mach NumPy squeeze behavior. See pyimagej/#1231
# TODO: better match NumPy squeeze behavior. See imagej/pyimagej#1231
dimension_reduced = Views.dropSingletonDimensions(stepped)
return dimension_reduced

Expand Down
64 changes: 58 additions & 6 deletions tests/test_image_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
import scyjava as sj
import xarray as xr

# TODO: Change to scyjava.new_jarray once we have that function.
from jpype import JArray, JInt, JLong

import imagej.dims as dims

# -- Fixtures --
Expand All @@ -18,7 +15,9 @@ def get_img(ij_fixture):
def _get_img():
# Create img
CreateNamespace = sj.jimport("net.imagej.ops.create.CreateNamespace")
dims = JArray(JLong)([1, 2, 3, 4, 5])
dims = sj.jarray("j", [5])
for i in range(len(dims)):
dims[i] = i + 1
ns = ij_fixture.op().namespace(CreateNamespace)
img = ns.img(dims)

Expand Down Expand Up @@ -77,6 +76,8 @@ def _get_nparr():

@pytest.fixture(scope="module")
def get_xarr():
name: str = "test_data_array"

def _get_xarr(option="C"):
if option == "C":
xarr = xr.DataArray(
Expand All @@ -90,6 +91,7 @@ def _get_xarr(option="C"):
"t": list(np.arange(0, 0.05, 0.01)),
},
attrs={"Hello": "World"},
name=name,
)
elif option == "F":
xarr = xr.DataArray(
Expand All @@ -102,9 +104,10 @@ def _get_xarr(option="C"):
"t": list(np.arange(0, 0.05, 0.01)),
},
attrs={"Hello": "World"},
name=name,
)
else:
xarr = xr.DataArray(np.random.rand(1, 2, 3, 4, 5))
xarr = xr.DataArray(np.random.rand(1, 2, 3, 4, 5), name=name)

return xarr

Expand All @@ -122,6 +125,7 @@ def assert_inverted_xarr_equal_to_xarr(dataset, ij_fixture, xarr):
for key in xarr.coords:
assert (xarr.coords[key] == invert_xarr.coords[key]).all()
assert xarr.attrs == invert_xarr.attrs
assert xarr.name == invert_xarr.name


def assert_ndarray_equal_to_ndarray(narr_1, narr_2):
Expand All @@ -130,7 +134,7 @@ def assert_ndarray_equal_to_ndarray(narr_1, narr_2):

def assert_ndarray_equal_to_img(img, nparr):
cursor = img.cursor()
arr = JArray(JInt)(5)
arr = sj.jarray("i", [5])
while cursor.hasNext():
y = cursor.next().get()
cursor.localize(arr)
Expand Down Expand Up @@ -251,6 +255,7 @@ def assert_xarray_equal_to_dataset(ij_fixture, xarr):

assert expected_labels == labels
assert xarr.attrs == ij_fixture.py.from_java(dataset.getProperties())
assert xarr.name == ij_fixture.py.from_java(dataset.getName())


def convert_img_and_assert_equality(ij_fixture, img):
Expand Down Expand Up @@ -292,6 +297,53 @@ def test_dataset_converts_to_xarray(ij_fixture, get_xarr):
assert_inverted_xarr_equal_to_xarr(dataset, ij_fixture, xarr)


def test_image_metadata_conversion(ij_fixture):
# Create a ImageMetadata
DefaultImageMetadata = sj.jimport("io.scif.DefaultImageMetadata")
IdentityAxis = sj.jimport("net.imagej.axis.IdentityAxis")
metadata = DefaultImageMetadata()
lengths = sj.jarray("j", [2])
lengths[0] = 4
lengths[1] = 2
metadata.populate(
"test", # name
ij_fixture.py.to_java([IdentityAxis(), IdentityAxis()]), # axes
lengths,
4, # pixelType
8, # bitsPerPixel
True, # orderCertain
True, # littleEndian
False, # indexed
False, # falseColor
True, # metadataComplete
)
# Some properties are computed on demand - since those computed values
# would not be grabbed in the map, let's set them
metadata.setThumbSizeX(metadata.getThumbSizeX())
metadata.setThumbSizeY(metadata.getThumbSizeY())
metadata.setInterleavedAxisCount(metadata.getInterleavedAxisCount())
# Convert to python
py_data = ij_fixture.py.from_java(metadata)
# Assert equality
assert py_data["thumbSizeX"] == metadata.getThumbSizeX()
assert py_data["thumbSizeY"] == metadata.getThumbSizeY()
assert py_data["pixelType"] == metadata.getPixelType()
assert py_data["bitsPerPixel"] == metadata.getBitsPerPixel()
assert py_data["axes"] == metadata.getAxes()
for axis in metadata.getAxes():
assert axis.type() in py_data["axisLengths"]
assert py_data["axisLengths"][axis.type()] == metadata.getAxisLength(axis)
assert py_data["orderCertain"] == metadata.isOrderCertain()
assert py_data["littleEndian"] == metadata.isLittleEndian()
assert py_data["indexed"] == metadata.isIndexed()
assert py_data["interleavedAxisCount"] == metadata.getInterleavedAxisCount()
assert py_data["falseColor"] == metadata.isFalseColor()
assert py_data["metadataComplete"] == metadata.isMetadataComplete()
assert py_data["thumbnail"] == metadata.isThumbnail()
assert py_data["rois"] == metadata.getROIs()
assert py_data["tables"] == metadata.getTables()


def test_rgb_image_maintains_correct_dim_order_on_conversion(ij_fixture, get_xarr):
xarr = get_xarr()
dataset = ij_fixture.py.to_java(xarr)
Expand Down
17 changes: 11 additions & 6 deletions tests/test_rai_arraylike.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
import pytest
import scyjava as sj

# TODO: Change to scyjava.new_jarray once we have that function.
from jpype import JArray, JLong

# -- Fixtures --


Expand Down Expand Up @@ -88,7 +85,10 @@ def test_slice_not_enough_dims(img):
def test_step(img):
# Create a stepped img via Views
Views = sj.jimport("net.imglib2.view.Views")
steps = JArray(JLong)([1, 1, 2])
steps = sj.jarray("j", 3)
steps[0] = 1
steps[1] = 1
steps[2] = 2
expected = Views.subsample(img, steps)
# Create a stepped img via slicing notation
actual = img[:, :, ::2]
Expand All @@ -101,7 +101,10 @@ def test_step(img):
def test_step_not_enough_dims(img):
# Create a stepped img via Views
Views = sj.jimport("net.imglib2.view.Views")
steps = JArray(JLong)([2, 1, 1])
steps = sj.jarray("j", 3)
steps[0] = 2
steps[1] = 1
steps[2] = 1
expected = Views.subsample(img, steps)
expected = Views.dropSingletonDimensions(expected)
# Create a stepped img via slicing notation
Expand All @@ -115,7 +118,9 @@ def test_slice_and_step(img):
# Create a stepped img via Views
Views = sj.jimport("net.imglib2.view.Views")
intervaled = Views.hyperSlice(img, 0, 0)
steps = JArray(JLong)([1, 2])
steps = sj.jarray("j", 2)
steps[0] = 1
steps[1] = 2
expected = Views.subsample(intervaled, steps)
# Create a stepped img via slicing notation
actual = img[:1, :, ::2]
Expand Down

0 comments on commit 570ff57

Please sign in to comment.