Skip to content

Commit

Permalink
Merge pull request #2891 from samuelgarcia/mpl_cmap
Browse files Browse the repository at this point in the history
Fix the new way of handling cmap in matpltolib. This fix the matplotib 3.9 problem related to this.
  • Loading branch information
alejoe91 authored May 23, 2024
2 parents 4f0bd0e + cd3dc7a commit bd88c1e
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ full = [
"scikit-learn",
"networkx",
"distinctipy",
"matplotlib<3.9", # See https://github.com/SpikeInterface/spikeinterface/issues/2863
"matplotlib>=3.6", # matplotlib.colormaps
"cuda-python; platform_system != 'Darwin'",
"numba",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fig

# n = self.motion.shape[1]
# step = int(np.ceil(max(1, n / show_only)))
# colors = plt.cm.get_cmap("jet", n)
# colors = plt.colormaps["jet"].resampled(n)
# for i in range(0, n, step):
# ax = axs[0]
# ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def create_benchmark(self, key):

# import matplotlib

# my_cmap = plt.get_cmap(cmap)
# my_cmap = plt.colormaps[cmap]
# cNorm = matplotlib.colors.Normalize(vmin=clim[0], vmax=clim[1])
# scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _split_waveforms(
local_feature_plot = local_feature

unique_lab = np.unique(local_labels_with_noise)
cmap = plt.get_cmap("jet", unique_lab.size)
cmap = plt.colormaps["jet"].resampled(unique_lab.size)
cmap = {k: cmap(l) for l, k in enumerate(unique_lab)}
cmap[-1] = "k"
active_ind = np.arange(local_feature.shape[0])
Expand Down Expand Up @@ -145,7 +145,7 @@ def _split_waveforms_nested(
local_feature_plot = reducer.fit_transform(local_feature)

unique_lab = np.unique(active_labels_with_noise)
cmap = plt.get_cmap("jet", unique_lab.size)
cmap = plt.colormaps["jet"].resampled(unique_lab.size)
cmap = {k: cmap(l) for l, k in enumerate(unique_lab)}
cmap[-1] = "k"
cmap[-2] = "b"
Expand Down Expand Up @@ -276,7 +276,7 @@ def auto_split_clustering(

fig, ax = plt.subplots()
plot_labels_set = np.unique(local_labels_with_noise)
cmap = plt.get_cmap("jet", plot_labels_set.size)
cmap = plt.colormaps["jet"].resampled(plot_labels_set.size)
cmap = {k: cmap(l) for l, k in enumerate(plot_labels_set)}
cmap[-1] = "k"
cmap[-2] = "b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d):
wfs_no_noise = wfs[: -noise.shape[0]]

fig, axs = plt.subplots(ncols=3)
cmap = plt.get_cmap("jet", np.unique(local_labels).size)
cmap = plt.colormaps["jet"].resampled(np.unique(local_labels).size)
cmap = {label: cmap(l) for l, label in enumerate(local_labels_set)}
cmap[-1] = "k"
for label in local_labels_set:
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/clustering/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def split(
import matplotlib.pyplot as plt

labels_set = np.setdiff1d(possible_labels, [-1])
colors = plt.get_cmap("tab10", len(labels_set))
colors = plt.colormaps["tab10"].resampled(len(labels_set))
colors = {k: colors(i) for i, k in enumerate(labels_set)}
colors[-1] = "k"
fix, axs = plt.subplots(nrows=2)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/widgets/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

ax1.set_xlabel("lag (ms)")
elif dp.mode == "lines":
my_cmap = plt.get_cmap(dp.cmap)
my_cmap = plt.colormaps[dp.cmap]
cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max())
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)

Expand Down Expand Up @@ -245,7 +245,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

study = dp.study

my_cmap = plt.get_cmap(dp.cmap)
my_cmap = plt.colormaps[dp.cmap]
cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max())
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)
study.precompute_scores_by_similarities(
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
if dp.scatter_decimate is not None:
amps = amps[:: dp.scatter_decimate]
amps_abs = amps_abs[:: dp.scatter_decimate]
cmap = plt.get_cmap(dp.amplitude_cmap)
cmap = plt.colormaps[dp.amplitude_cmap]
if dp.amplitude_clim is None:
amps = amps_abs
amps /= q_95
Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/widgets/multicomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
nodelist=sorted(g.nodes),
edge_color=edge_col,
alpha=dp.alpha_edges,
edge_cmap=plt.cm.get_cmap(dp.edge_cmap),
edge_cmap=plt.colormaps[dp.edge_cmap],
edge_vmin=mcmp.match_score,
edge_vmax=1,
ax=self.ax,
Expand All @@ -106,7 +106,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt

norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1)
cmap = plt.cm.get_cmap(dp.edge_cmap)
cmap = plt.colormaps[dp.edge_cmap]
m = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
self.figure.colorbar(m)

Expand Down Expand Up @@ -159,7 +159,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

mcmp = dp.multi_comparison
cmap = plt.get_cmap(dp.cmap)
cmap = plt.colormaps[dp.cmap]
colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))])
sg_names, sg_units = mcmp.compute_subgraphs()
# fraction of units with agreement > threshold
Expand Down Expand Up @@ -242,7 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
backend_kwargs["ncols"] = len(name_list)
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

cmap = plt.get_cmap(dp.cmap)
cmap = plt.colormaps[dp.cmap]
colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))])
sg_names, sg_units = mcmp.compute_subgraphs()
# fraction of units with agreement > threshold
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB
elif color_engine == "matplotlib":
# some map have black or white at border so +10
margin = max(4, int(N * 0.08))
cmap = plt.get_cmap(map_name, N + 2 * margin)
cmap = plt.colormaps[map_name].resampled(N + 2 * margin)

colors = [cmap(i + margin) for i, key in enumerate(keys)]

elif color_engine == "colorsys":
Expand Down Expand Up @@ -160,7 +161,7 @@ def array_to_image(
num_channels = data.shape[1]
spacing = int(num_channels * spatial_zoom[1] * row_spacing)

cmap = plt.get_cmap(colormap)
cmap = plt.colormaps[colormap]
zoomed_data = zoom(data, spatial_zoom)
num_timepoints_after_scaling, num_channels_after_scaling = zoomed_data.shape
num_timepoints_per_row_after_scaling = int(np.min([num_timepoints_per_row, num_timepoints]) * spatial_zoom[0])
Expand Down

0 comments on commit bd88c1e

Please sign in to comment.