Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed a bug in named_arrays.random.binomial() where the units weren't being handled properly. #101

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
na.random.uniform,
na.random.normal,
na.random.poisson,
na.random.binomial,
)
PLT_PLOT_LIKE_FUNCTIONS = (
na.plt.plot,
Expand Down Expand Up @@ -579,6 +578,53 @@
)


@_implements(na.random.binomial)
def random_binomial(
n: int | u.Quantity | na.AbstractScalarArray,
p: float | na.AbstractScalarArray,
shape_random: None | dict[str, int] = None,
seed: None | int = None,
):
try:
n = scalars._normalize(n)
p = scalars._normalize(p)
except na.ScalarTypeError:
return NotImplemented

Check warning on line 592 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L591-L592

Added lines #L591 - L592 were not covered by tests

if shape_random is None:
shape_random = dict()

Check warning on line 595 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L595

Added line #L595 was not covered by tests

shape_base = na.shape_broadcasted(n, p)
shape = na.broadcast_shapes(shape_base, shape_random)

n = n.ndarray_aligned(shape)
p = p.ndarray_aligned(shape)

unit = na.unit(n)

if unit is not None:
n = n.value

Check warning on line 606 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L606

Added line #L606 was not covered by tests

if seed is None:
func = np.random.binomial
else:
func = np.random.default_rng(seed).binomial

Check warning on line 611 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L611

Added line #L611 was not covered by tests

value = func(
n=n,
p=p,
size=tuple(shape.values()),
)

if unit is not None:
value = value << unit

Check warning on line 620 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L620

Added line #L620 was not covered by tests

return na.ScalarArray(
ndarray=value,
axes=tuple(shape.keys()),
)


def plt_plot_like(
func: Callable,
*args: na.AbstractScalarArray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
]

ASARRAY_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.ASARRAY_LIKE_FUNCTIONS
RANDOM_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.RANDOM_FUNCTIONS
RANDOM_FUNCTIONS = (
na.random.uniform,
na.random.normal,
na.random.poisson,
na.random.binomial,
)
PLT_PLOT_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.PLT_PLOT_LIKE_FUNCTIONS
NDFILTER_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.NDFILTER_FUNCTIONS
HANDLED_FUNCTIONS = dict()
Expand Down
7 changes: 6 additions & 1 deletion named_arrays/_vectors/vector_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
OutputT = TypeVar("OutputT", bound="float | u.Quantity | na.AbstractVectorArray")

ASARRAY_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.ASARRAY_LIKE_FUNCTIONS
RANDOM_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.RANDOM_FUNCTIONS
RANDOM_FUNCTIONS = (
na.random.uniform,
na.random.normal,
na.random.poisson,
na.random.binomial,
)
PLT_PLOT_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.PLT_PLOT_LIKE_FUNCTIONS
NDFILTER_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.NDFILTER_FUNCTIONS
HANDLED_FUNCTIONS = dict()
Expand Down
57 changes: 57 additions & 0 deletions named_arrays/tests/test_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import numpy as np
import astropy.units as u
import named_arrays as na


@pytest.mark.parametrize(
argnames="n",
argvalues=[
10,
(11 * u.photon).astype(int),
na.ScalarArray(12),
(na.arange(1, 10, axis="x") << u.photon).astype(int),
na.Cartesian2dVectorArray(10, 11),
],
)
@pytest.mark.parametrize(
argnames="p",
argvalues=[
0.5,
na.ScalarArray(0.51),
na.linspace(0.4, 0.5, axis="p", num=5),
na.UniformUncertainScalarArray(0.5, width=0.1),
na.Cartesian2dVectorArray(0.5, 0.6),
],
)
@pytest.mark.parametrize(
argnames="shape_random",
argvalues=[
None,
dict(_s=6),
],
)
@pytest.mark.parametrize(
argnames="seed",
argvalues=[
None,
42,
],
)
def test_binomial(
n: int | u.Quantity | na.AbstractScalar | na.AbstractVectorArray,
p: float | na.AbstractScalar | na.AbstractVectorArray,
shape_random: None | dict[str, int],
seed: None | int,
):
result = na.random.binomial(

Check warning on line 47 in named_arrays/tests/test_random.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/tests/test_random.py#L47

Added line #L47 was not covered by tests
n=n,
p=p,
shape_random=shape_random,
seed=seed,
)

assert na.unit(result) == na.unit(n)

Check warning on line 54 in named_arrays/tests/test_random.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/tests/test_random.py#L54

Added line #L54 was not covered by tests

assert np.all(result >= 0)
assert np.all(result <= n)

Check warning on line 57 in named_arrays/tests/test_random.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/tests/test_random.py#L56-L57

Added lines #L56 - L57 were not covered by tests
Loading