From 614a84cfb955668b87a53d9b5d9642e2a662ccdb Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 May 2024 17:40:36 +0200 Subject: [PATCH 1/3] Fix the new way of handling cmap in matpltolib. This fix the matplotlib 3.9 problem related to this. --- .../benchmark/benchmark_motion_estimation.py | 2 +- .../benchmark/benchmark_peak_selection.py | 2 +- .../sortingcomponents/clustering/clustering_tools.py | 6 +++--- .../sortingcomponents/clustering/sliding_hdbscan.py | 2 +- src/spikeinterface/sortingcomponents/clustering/split.py | 2 +- src/spikeinterface/widgets/collision.py | 4 ++-- src/spikeinterface/widgets/motion.py | 2 +- src/spikeinterface/widgets/multicomparison.py | 8 ++++---- src/spikeinterface/widgets/utils.py | 5 +++-- 9 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 30175288a3..3c5623f202 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -728,7 +728,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fig # n = self.motion.shape[1] # step = int(np.ceil(max(1, n / show_only))) -# colors = plt.cm.get_cmap("jet", n) +# colors = plt.colormaps["jet"].resampled(n) # for i in range(0, n, step): # ax = axs[0] # ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index d3875ca33d..008de2d931 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -382,7 +382,7 @@ def create_benchmark(self, key): # import matplotlib -# my_cmap = plt.get_cmap(cmap) +# my_cmap = plt.colormaps[cmap] # cNorm = matplotlib.colors.Normalize(vmin=clim[0], vmax=clim[1]) # scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index d3a00c4e6e..083e0077f6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -62,7 +62,7 @@ def _split_waveforms( local_feature_plot = local_feature unique_lab = np.unique(local_labels_with_noise) - cmap = plt.get_cmap("jet", unique_lab.size) + cmap = plt.colormaps["jet"].resampled(unique_lab.size) cmap = {k: cmap(l) for l, k in enumerate(unique_lab)} cmap[-1] = "k" active_ind = np.arange(local_feature.shape[0]) @@ -145,7 +145,7 @@ def _split_waveforms_nested( local_feature_plot = reducer.fit_transform(local_feature) unique_lab = np.unique(active_labels_with_noise) - cmap = plt.get_cmap("jet", unique_lab.size) + cmap = plt.colormaps["jet"].resampled(unique_lab.size) cmap = {k: cmap(l) for l, k in enumerate(unique_lab)} cmap[-1] = "k" cmap[-2] = "b" @@ -276,7 +276,7 @@ def auto_split_clustering( fig, ax = plt.subplots() plot_labels_set = np.unique(local_labels_with_noise) - cmap = plt.get_cmap("jet", plot_labels_set.size) + cmap = plt.colormaps["jet"].resampled(plot_labels_set.size) cmap = {k: cmap(l) for l, k in enumerate(plot_labels_set)} cmap[-1] = "k" cmap[-2] = "b" diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py index 7e7a8de1d7..2ae22ce07d 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_hdbscan.py @@ -349,7 +349,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d): wfs_no_noise = wfs[: -noise.shape[0]] fig, axs = plt.subplots(ncols=3) - cmap = plt.get_cmap("jet", np.unique(local_labels).size) + cmap = plt.colormaps["jet"].resampled(np.unique(local_labels).size) cmap = {label: cmap(l) for l, label in enumerate(local_labels_set)} cmap[-1] = "k" for label in local_labels_set: diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index ceeaeb6633..45f2f44753 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -254,7 +254,7 @@ def split( import matplotlib.pyplot as plt labels_set = np.setdiff1d(possible_labels, [-1]) - colors = plt.get_cmap("tab10", len(labels_set)) + colors = plt.colormaps["tab10"].resampled(len(labels_set)) colors = {k: colors(i) for i, k in enumerate(labels_set)} colors[-1] = "k" fix, axs = plt.subplots(nrows=2) diff --git a/src/spikeinterface/widgets/collision.py b/src/spikeinterface/widgets/collision.py index a5b5891110..34f65a2f89 100644 --- a/src/spikeinterface/widgets/collision.py +++ b/src/spikeinterface/widgets/collision.py @@ -136,7 +136,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.set_xlabel("lag (ms)") elif dp.mode == "lines": - my_cmap = plt.get_cmap(dp.cmap) + my_cmap = plt.colormaps[dp.cmap] cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) @@ -245,7 +245,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): study = dp.study - my_cmap = plt.get_cmap(dp.cmap) + my_cmap = plt.colormaps[dp.cmap] cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max()) scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap) study.precompute_scores_by_similarities( diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 2e4efc82b0..9d64c89e46 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -128,7 +128,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.scatter_decimate is not None: amps = amps[:: dp.scatter_decimate] amps_abs = amps_abs[:: dp.scatter_decimate] - cmap = plt.get_cmap(dp.amplitude_cmap) + cmap = plt.colormaps[dp.amplitude_cmap] if dp.amplitude_clim is None: amps = amps_abs amps /= q_95 diff --git a/src/spikeinterface/widgets/multicomparison.py b/src/spikeinterface/widgets/multicomparison.py index 78693aacc2..2d4a22a2b3 100644 --- a/src/spikeinterface/widgets/multicomparison.py +++ b/src/spikeinterface/widgets/multicomparison.py @@ -87,7 +87,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): nodelist=sorted(g.nodes), edge_color=edge_col, alpha=dp.alpha_edges, - edge_cmap=plt.cm.get_cmap(dp.edge_cmap), + edge_cmap=plt.colormaps[dp.edge_cmap], edge_vmin=mcmp.match_score, edge_vmax=1, ax=self.ax, @@ -106,7 +106,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1) - cmap = plt.cm.get_cmap(dp.edge_cmap) + cmap = plt.colormaps[dp.edge_cmap] m = plt.cm.ScalarMappable(norm=norm, cmap=cmap) self.figure.colorbar(m) @@ -159,7 +159,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) mcmp = dp.multi_comparison - cmap = plt.get_cmap(dp.cmap) + cmap = plt.colormaps[dp.cmap] colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) sg_names, sg_units = mcmp.compute_subgraphs() # fraction of units with agreement > threshold @@ -242,7 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): backend_kwargs["ncols"] = len(name_list) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - cmap = plt.get_cmap(dp.cmap) + cmap = plt.colormaps[dp.cmap] colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))]) sg_names, sg_units = mcmp.compute_subgraphs() # fraction of units with agreement > threshold diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 29e6474ee9..9536941c07 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -76,7 +76,8 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB elif color_engine == "matplotlib": # some map have black or white at border so +10 margin = max(4, int(N * 0.08)) - cmap = plt.get_cmap(map_name, N + 2 * margin) + cmap = plt.colormaps[map_name].resampled(N + 2 * margin) + colors = [cmap(i + margin) for i, key in enumerate(keys)] elif color_engine == "colorsys": @@ -153,7 +154,7 @@ def array_to_image( num_channels = data.shape[1] spacing = int(num_channels * spatial_zoom[1] * row_spacing) - cmap = plt.get_cmap(colormap) + cmap = plt.colormaps[colormap] zoomed_data = zoom(data, spatial_zoom) num_timepoints_after_scaling, num_channels_after_scaling = zoomed_data.shape num_timepoints_per_row_after_scaling = int(np.min([num_timepoints_per_row, num_timepoints]) * spatial_zoom[0]) From cc48a409fcb65bbb9477116ff6a1a42be0395309 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 May 2024 17:41:56 +0200 Subject: [PATCH 2/3] Try to remove the mpl boundary to run tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 89ea05e5bf..bc04a1bcd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ full = [ "scikit-learn", "networkx", "distinctipy", - "matplotlib<3.9", # See https://github.com/SpikeInterface/spikeinterface/issues/2863 + "matplotlib", "cuda-python; platform_system != 'Darwin'", "numba", ] From 4b67b2099b6a1a1682b33a656a25bd9fdc2c13a2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 21 May 2024 20:52:31 +0200 Subject: [PATCH 3/3] mpl 3.6 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bc04a1bcd5..d040a4a36b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ full = [ "scikit-learn", "networkx", "distinctipy", - "matplotlib", + "matplotlib>=3.6", # matplotlib.colormaps "cuda-python; platform_system != 'Darwin'", "numba", ]