Skip to content

Commit

Permalink
tests: test the conversion from 3d data
Browse files Browse the repository at this point in the history
  • Loading branch information
Eoghan O'Connell committed Jan 22, 2025
1 parent 286ed5a commit da0f046
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
17 changes: 14 additions & 3 deletions qpretrieve/data_array_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,20 @@ def convert_data_to_3d_array_layout(data):


def convert_3d_data_to_array_layout(data, array_layout):
"""Convert the 3d data to the desired `array_layout`"""
assert array_layout in get_allowed_array_layouts()
assert len(data.shape) == 3, "the data should be 3d"
"""Convert the 3d data to the desired `array_layout`.
Notes
-----
Currently, this function is limited to converting from 3d to other
array layouts. Perhaps if there is demand in the future,
this can be generalised for other conversions.
"""
assert array_layout in get_allowed_array_layouts(), (
f"`array_layout` not allowed. "
f"Allowed layouts are: {get_allowed_array_layouts()}.")
assert len(data.shape) == 3, (
f"The data should be 3d, got {len(data.shape)=}")
data = data.copy()
if array_layout == "rgb":
data = _convert_3d_to_rgb(data)
Expand Down
55 changes: 55 additions & 0 deletions tests/test_array_layout_convert_from_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
import pytest

from qpretrieve.data_array_layout import (
convert_3d_data_to_array_layout,
_convert_3d_to_2d, _convert_3d_to_rgba, _convert_3d_to_rgb,
)


def test_convert_3d_data_to_2d():
data = np.zeros(shape=(10, 256, 256))
array_layout = "2d"

data_new = convert_3d_data_to_array_layout(data, array_layout)
data_direct = _convert_3d_to_2d(data) # this is the internal function

assert np.array_equal(data[0], data_new)
assert data_new.shape == data_direct.shape == (256, 256)
assert np.array_equal(data_direct, data_new)


def test_convert_3d_data_to_rgb():
data = np.zeros(shape=(10, 256, 256))
array_layout = "rgb"

data_new = convert_3d_data_to_array_layout(data, array_layout)
data_direct = _convert_3d_to_rgb(data) # this is the internal function

assert data_new.shape == data_direct.shape == (256, 256, 3)
assert np.array_equal(data_direct, data_new)


def test_convert_3d_data_to_rgba():
data = np.zeros(shape=(10, 256, 256))
array_layout = "rgba"

data_new = convert_3d_data_to_array_layout(data, array_layout)
data_direct = _convert_3d_to_rgba(data) # this is the internal function

assert data_new.shape == data_direct.shape == (256, 256, 4)
assert np.array_equal(data_direct, data_new)


def test_convert_3d_data_to_array_layout_bad_input():
data = np.zeros(shape=(10, 256, 256))
array_layout = "5d"

with pytest.raises(AssertionError, match="`array_layout` not allowed."):
convert_3d_data_to_array_layout(data, array_layout)

data = np.zeros(shape=(256, 256))
array_layout = "2d"

with pytest.raises(AssertionError, match="The data should be 3d"):
convert_3d_data_to_array_layout(data, array_layout)
File renamed without changes.

0 comments on commit da0f046

Please sign in to comment.