Skip to content

Commit

Permalink
Updated named_arrays.plt.scatter() to handle matplotlib >= 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
byrdie committed Jan 15, 2024
1 parent 78f1089 commit 3a07bec
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 29 deletions.
24 changes: 21 additions & 3 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,12 @@ def plt_scatter(
ax = plt.gca()
ax = na.as_named_array(ax)

Check warning on line 369 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L367-L369

Added lines #L367 - L369 were not covered by tests

shape = na.shape_broadcasted(*args, s, c, ax, where)
shape_c = c.shape
if "rgba" in c.shape:
shape_c.pop("rgba")

Check warning on line 373 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L371-L373

Added lines #L371 - L373 were not covered by tests

shape = na.shape_broadcasted(*args, s, ax, where)
shape = na.broadcast_shapes(shape, shape_c)

Check warning on line 376 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L375-L376

Added lines #L375 - L376 were not covered by tests

shape_orthogonal = ax.shape

Check warning on line 378 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L378

Added line #L378 was not covered by tests

Expand All @@ -382,7 +387,10 @@ def plt_scatter(
if np.issubdtype(na.get_dtype(c), np.number):
c = na.broadcast_to(c, shape)

Check warning on line 388 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L387-L388

Added lines #L387 - L388 were not covered by tests
else:
c = na.broadcast_to(c, shape_orthogonal)
if "rgba" in c.shape:
c = na.broadcast_to(c, shape_orthogonal | dict(rgba=c.shape["rgba"]))

Check warning on line 391 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L390-L391

Added lines #L390 - L391 were not covered by tests
else:
c = na.broadcast_to(c, shape_orthogonal)

Check warning on line 393 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L393

Added line #L393 was not covered by tests

where = where.broadcast_to(shape)

Check warning on line 395 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L395

Added line #L395 was not covered by tests

Expand All @@ -403,9 +411,19 @@ def plt_scatter(

for index in na.ndindex(shape_orthogonal):
func_matplotlib = getattr(ax[index].ndarray, "scatter")
args_index = tuple(arg[index].ndarray for arg in args)
args_index = tuple(arg[index].ndarray.reshape(-1) for arg in args)

Check warning on line 414 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L412-L414

Added lines #L412 - L414 were not covered by tests

s_index = s[index].ndarray
if s_index is not None:
s_index = s_index.reshape(-1)

Check warning on line 418 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L416-L418

Added lines #L416 - L418 were not covered by tests

c_index = c[index].ndarray
if c_index is not None:
if "rgba" in c.shape:
c_index = c[index].ndarray.reshape(-1, c.shape["rgba"])

Check warning on line 423 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L420-L423

Added lines #L420 - L423 were not covered by tests
else:
c_index = c[index].ndarray.reshape(-1)

Check warning on line 425 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L425

Added line #L425 was not covered by tests

kwargs_index = {k: kwargs[k][index].ndarray for k in kwargs}
result[index] = func_matplotlib(

Check warning on line 428 in named_arrays/_scalars/scalar_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/scalar_named_array_functions.py#L427-L428

Added lines #L427 - L428 were not covered by tests
*args_index,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def plt_scatter(
try:
args = tuple(uncertainties._normalize(arg) for arg in args)
s = uncertainties._normalize(s)
c = uncertainties._normalize(c) if c is not None else c
c = uncertainties._normalize(c)
where = uncertainties._normalize(where)
kwargs = {k: uncertainties._normalize(kwargs[k]) for k in kwargs}
except na.UncertainScalarTypeError:
Expand Down Expand Up @@ -295,33 +295,42 @@ def plt_scatter(
else:
kwargs["alpha"] = na.UncertainScalarArray(1, alpha)

Check warning on line 296 in named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py#L296

Added line #L296 was not covered by tests

if c is None:
c = na.ScalarArray.empty(shape=ax.shape, dtype=object)
result_nominal = na.plt.scatter(

Check warning on line 298 in named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py#L298

Added line #L298 was not covered by tests
*[na.as_named_array(arg.nominal) for arg in args],
s=s.nominal,
c=c.nominal,
ax=ax,
where=where.nominal,
components=components,
**{k: kwargs[k].nominal for k in kwargs},
)

if c.distribution is None:
c_distribution = na.ScalarArray.zeros(shape=ax.shape | dict(rgba=4))
for index in ax.ndindex():
c[index] = next(ax[index].ndarray._get_lines.prop_cycler)['color']
c = na.UncertainScalarArray(c, c)
facecolor = result_nominal[index].ndarray.get_facecolor()[0]
c_distribution[index] = na.ScalarArray(

Check warning on line 312 in named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py#L308-L312

Added lines #L308 - L312 were not covered by tests
ndarray=facecolor,
axes="rgba",
)
c.distribution = c_distribution

Check warning on line 316 in named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py#L316

Added line #L316 was not covered by tests

result_distribution = na.plt.scatter(

Check warning on line 318 in named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py#L318

Added line #L318 was not covered by tests
*[na.as_named_array(arg.distribution) for arg in args],
s=s.distribution,
c=c.distribution,
ax=ax,
where=where.distribution,
components=components,
**{k: kwargs[k].distribution for k in kwargs},
)

result = na.UncertainScalarArray(

Check warning on line 328 in named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py#L328

Added line #L328 was not covered by tests
nominal=na.plt.scatter(
*[na.as_named_array(arg.nominal) for arg in args],
s=s.nominal,
c=c.nominal,
ax=ax,
where=where.nominal,
components=components,
**{k: kwargs[k].nominal for k in kwargs},
),
distribution=na.plt.scatter(
*[na.as_named_array(arg.distribution) for arg in args],
s=s.distribution,
c=c.distribution,
ax=ax,
where=where.distribution,
components=components,
**{k: kwargs[k].distribution for k in kwargs},
)
nominal=result_nominal,
distribution=result_distribution,
)

return result


@_implements(na.jacobian)

Check warning on line 336 in named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py

View check run for this annotation

Codecov / codecov/patch

named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py#L333-L336

Added lines #L333 - L336 were not covered by tests
Expand Down
6 changes: 3 additions & 3 deletions named_arrays/plt.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def scatter(
The component names of ``*args`` to plot, helpful if ``*args`` are an instance of
:class:`named_arrays.AbstractVectorArray`.
kwargs
Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.plot`.
Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.scatter`.
These can be instances of :class:`named_arrays.AbstractArray`.
Examples
Expand All @@ -296,7 +296,7 @@ def scatter(
import matplotlib.pyplot as plt
import named_arrays as na
x = na.linspace(0, 2 * np.pi, axis="x", num=101)
x = na.linspace(0, 2 * np.pi, axis="x", num=51)
y = np.sin(x)
plt.figure();
Expand Down Expand Up @@ -329,7 +329,7 @@ def scatter(
fig, ax = na.plt.subplots(axis_rows="z", nrows=z.shape["z"], sharex=True)
na.plt.scatter(x, y, ax=ax, axis="x");
na.plt.scatter(x, y, ax=ax);
"""
if transformation is not None:
args = tuple(transformation(arg) for arg in args)
Expand Down

0 comments on commit 3a07bec

Please sign in to comment.