Skip to content

Commit

Permalink
Fix the new way of handling cmap in matpltolib. This fix the matplotl…
Browse files Browse the repository at this point in the history
…ib 3.9 problem related to this.
  • Loading branch information
samuelgarcia committed May 21, 2024
1 parent 70ebe17 commit 614a84c
Show file tree
Hide file tree
Showing 9 changed files with 17 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, colors=None, fig

# n = self.motion.shape[1]
# step = int(np.ceil(max(1, n / show_only)))
# colors = plt.cm.get_cmap("jet", n)
# colors = plt.colormaps["jet"].resampled(n)
# for i in range(0, n, step):
# ax = axs[0]
# ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def create_benchmark(self, key):

# import matplotlib

# my_cmap = plt.get_cmap(cmap)
# my_cmap = plt.colormaps[cmap]
# cNorm = matplotlib.colors.Normalize(vmin=clim[0], vmax=clim[1])
# scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _split_waveforms(
local_feature_plot = local_feature

unique_lab = np.unique(local_labels_with_noise)
cmap = plt.get_cmap("jet", unique_lab.size)
cmap = plt.colormaps["jet"].resampled(unique_lab.size)
cmap = {k: cmap(l) for l, k in enumerate(unique_lab)}
cmap[-1] = "k"
active_ind = np.arange(local_feature.shape[0])
Expand Down Expand Up @@ -145,7 +145,7 @@ def _split_waveforms_nested(
local_feature_plot = reducer.fit_transform(local_feature)

unique_lab = np.unique(active_labels_with_noise)
cmap = plt.get_cmap("jet", unique_lab.size)
cmap = plt.colormaps["jet"].resampled(unique_lab.size)
cmap = {k: cmap(l) for l, k in enumerate(unique_lab)}
cmap[-1] = "k"
cmap[-2] = "b"
Expand Down Expand Up @@ -276,7 +276,7 @@ def auto_split_clustering(

fig, ax = plt.subplots()
plot_labels_set = np.unique(local_labels_with_noise)
cmap = plt.get_cmap("jet", plot_labels_set.size)
cmap = plt.colormaps["jet"].resampled(plot_labels_set.size)
cmap = {k: cmap(l) for l, k in enumerate(plot_labels_set)}
cmap[-1] = "k"
cmap[-2] = "b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _find_clusters(cls, recording, peaks, wfs_arrays, sparsity_mask, noise, d):
wfs_no_noise = wfs[: -noise.shape[0]]

fig, axs = plt.subplots(ncols=3)
cmap = plt.get_cmap("jet", np.unique(local_labels).size)
cmap = plt.colormaps["jet"].resampled(np.unique(local_labels).size)
cmap = {label: cmap(l) for l, label in enumerate(local_labels_set)}
cmap[-1] = "k"
for label in local_labels_set:
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/clustering/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def split(
import matplotlib.pyplot as plt

labels_set = np.setdiff1d(possible_labels, [-1])
colors = plt.get_cmap("tab10", len(labels_set))
colors = plt.colormaps["tab10"].resampled(len(labels_set))
colors = {k: colors(i) for i, k in enumerate(labels_set)}
colors[-1] = "k"
fix, axs = plt.subplots(nrows=2)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/widgets/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

ax1.set_xlabel("lag (ms)")
elif dp.mode == "lines":
my_cmap = plt.get_cmap(dp.cmap)
my_cmap = plt.colormaps[dp.cmap]
cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max())
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)

Expand Down Expand Up @@ -245,7 +245,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

study = dp.study

my_cmap = plt.get_cmap(dp.cmap)
my_cmap = plt.colormaps[dp.cmap]
cNorm = matplotlib.colors.Normalize(vmin=dp.similarity_bins.min(), vmax=dp.similarity_bins.max())
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)
study.precompute_scores_by_similarities(
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
if dp.scatter_decimate is not None:
amps = amps[:: dp.scatter_decimate]
amps_abs = amps_abs[:: dp.scatter_decimate]
cmap = plt.get_cmap(dp.amplitude_cmap)
cmap = plt.colormaps[dp.amplitude_cmap]
if dp.amplitude_clim is None:
amps = amps_abs
amps /= q_95
Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/widgets/multicomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
nodelist=sorted(g.nodes),
edge_color=edge_col,
alpha=dp.alpha_edges,
edge_cmap=plt.cm.get_cmap(dp.edge_cmap),
edge_cmap=plt.colormaps[dp.edge_cmap],
edge_vmin=mcmp.match_score,
edge_vmax=1,
ax=self.ax,
Expand All @@ -106,7 +106,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt

norm = mpl_colors.Normalize(vmin=mcmp.match_score, vmax=1)
cmap = plt.cm.get_cmap(dp.edge_cmap)
cmap = plt.colormaps[dp.edge_cmap]
m = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
self.figure.colorbar(m)

Expand Down Expand Up @@ -159,7 +159,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

mcmp = dp.multi_comparison
cmap = plt.get_cmap(dp.cmap)
cmap = plt.colormaps[dp.cmap]
colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))])
sg_names, sg_units = mcmp.compute_subgraphs()
# fraction of units with agreement > threshold
Expand Down Expand Up @@ -242,7 +242,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
backend_kwargs["ncols"] = len(name_list)
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

cmap = plt.get_cmap(dp.cmap)
cmap = plt.colormaps[dp.cmap]
colors = np.array([cmap(i) for i in np.linspace(0.1, 0.8, len(mcmp.name_list))])
sg_names, sg_units = mcmp.compute_subgraphs()
# fraction of units with agreement > threshold
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB
elif color_engine == "matplotlib":
# some map have black or white at border so +10
margin = max(4, int(N * 0.08))
cmap = plt.get_cmap(map_name, N + 2 * margin)
cmap = plt.colormaps[map_name].resampled(N + 2 * margin)

colors = [cmap(i + margin) for i, key in enumerate(keys)]

elif color_engine == "colorsys":
Expand Down Expand Up @@ -153,7 +154,7 @@ def array_to_image(
num_channels = data.shape[1]
spacing = int(num_channels * spatial_zoom[1] * row_spacing)

cmap = plt.get_cmap(colormap)
cmap = plt.colormaps[colormap]
zoomed_data = zoom(data, spatial_zoom)
num_timepoints_after_scaling, num_channels_after_scaling = zoomed_data.shape
num_timepoints_per_row_after_scaling = int(np.min([num_timepoints_per_row, num_timepoints]) * spatial_zoom[0])
Expand Down

0 comments on commit 614a84c

Please sign in to comment.