Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug with 3d data test in transform_to_data_extent #787

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/spatialdata/_core/operations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xarray import DataArray, DataTree

from spatialdata.models import SpatialElement
from spatialdata.models import SpatialElement, get_axes_names, get_spatial_axes

if TYPE_CHECKING:
from spatialdata._core.spatialdata import SpatialData
Expand Down Expand Up @@ -114,12 +114,13 @@ def transform_to_data_extent(
}

for _, element_name, element in sdata_raster.gen_spatial_elements():
element_axes = get_spatial_axes(get_axes_names(element))
if isinstance(element, DataArray | DataTree):
rasterized = rasterize(
element,
axes=data_extent_axes,
min_coordinate=[data_extent[ax][0] for ax in data_extent_axes],
max_coordinate=[data_extent[ax][1] for ax in data_extent_axes],
axes=element_axes,
min_coordinate=[data_extent[ax][0] for ax in element_axes],
max_coordinate=[data_extent[ax][1] for ax in element_axes],
target_coordinate_system=coordinate_system,
target_unit_to_pixels=None,
target_width=target_width,
Expand Down
5 changes: 4 additions & 1 deletion src/spatialdata/_core/operations/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,10 @@ def rasterize_images_labels(
target_coordinate_system=target_coordinate_system,
)

half_pixel_offset = Translation([0.5, 0.5, 0.5], axes=("z", "y", "x"))
if "z" in spatial_axes:
half_pixel_offset = Translation([0.5, 0.5, 0.5], axes=("z", "y", "x"))
else:
half_pixel_offset = Translation([0.5, 0.5], axes=("y", "x"))
sequence = Sequence(
[
# half_pixel_offset.inverse(),
Expand Down
54 changes: 34 additions & 20 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from spatialdata._core.data_extent import are_extents_equal, get_extent
from spatialdata._core.operations._utils import transform_to_data_extent
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.datasets import blobs
from spatialdata.models import (
Image2DModel,
Labels2DModel,
PointsModel,
ShapesModel,
TableModel,
get_model,
get_table_keys,
)
from spatialdata.testing import assert_elements_dict_are_identical, assert_spatial_data_objects_are_identical
Expand Down Expand Up @@ -490,32 +490,46 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning:
"poly",
]
full_sdata = full_sdata.subset(elements)
points = full_sdata["points_0"].compute()
points["z"] = points["x"]
points = PointsModel.parse(points)
full_sdata["points_0_3d"] = points
sdata = transform_to_data_extent(full_sdata, "global", target_width=1000, maintain_positioning=maintain_positioning)

matrices = []
for el in sdata._gen_spatial_element_values():
first_a: ArrayLike | None = None
for _, name, el in sdata.gen_spatial_elements():
t = get_transformation(el, to_coordinate_system="global")
assert isinstance(t, BaseTransformation)
a = t.to_affine_matrix(input_axes=("x", "y", "z"), output_axes=("x", "y", "z"))
matrices.append(a)
a = t.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
if first_a is None:
first_a = a
else:
# we are not pixel perfect because of this bug: https://github.com/scverse/spatialdata/issues/165
if maintain_positioning and name in ["points_0_3d", "points_0", "poly", "circles", "multipoly"]:
# Again, due to the "pixel perfect" bug, the 0.5 translation forth and back in the z axis that is added
# by rasterize() (like the one in the example belows), amplifies the error also for x and y beyond the
# rtol threshold below. So, let's skip that check and to an absolute check up to 0.5 (due to the
# half-pixel offset).
# Sequence
# Translation (z, y, x)
# [-0.5 -0.5 -0.5]
# Scale (y, x)
# [0.17482681 0.17485125]
# Translation (y, x)
# [ -3.13652607 -164. ]
# Translation (z, y, x)
# [0.5 0.5 0.5]
assert np.allclose(a, first_a, atol=0.5)
else:
assert np.allclose(a, first_a, rtol=0.005)

first_a = matrices[0]
for a in matrices[1:]:
# we are not pixel perfect because of this bug: https://github.com/scverse/spatialdata/issues/165
assert np.allclose(a, first_a, rtol=0.005)
if not maintain_positioning:
assert np.allclose(first_a, np.eye(4))
assert np.allclose(first_a, np.eye(3))
else:
for element in elements:
before = full_sdata[element]
after = sdata[element]
assert get_model(after) == get_model(before)
data_extent_before = get_extent(before, coordinate_system="global")
data_extent_after = get_extent(after, coordinate_system="global")
# huge tolerance because of the bug with pixel perfectness
assert are_extents_equal(
data_extent_before, data_extent_after, atol=4
), f"data_extent_before: {data_extent_before}, data_extent_after: {data_extent_after} for element {element}"
data_extent_before = get_extent(full_sdata, coordinate_system="global")
data_extent_after = get_extent(sdata, coordinate_system="global")
# again, due to the "pixel perfect" bug, we use an absolute tolerance of 0.5
assert are_extents_equal(data_extent_before, data_extent_after, atol=0.5)


def test_validate_table_in_spatialdata(full_sdata):
Expand Down
Loading