Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #72 from SpikeInterface/activity_update
Browse files Browse the repository at this point in the history
Modified plot_activity_map with new spiketoolkit changes
  • Loading branch information
alejoe91 authored Nov 25, 2020
2 parents e0c0528 + cc8bc7d commit 0c27a8c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
4 changes: 4 additions & 0 deletions spikewidgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def test_spectrogram(self):
def test_geometry(self):
sw.plot_electrode_geometry(self._RX)

def test_activitymap(self):
sw.plot_activity_map(self._RX, activity='rate')
sw.plot_activity_map(self._RX, activity='amplitude')

def test_unitwaveforms(self):
sw.plot_unit_waveforms(self._RX, self._SX)
fig, axes = plt.subplots(self.num_units, 1)
Expand Down
68 changes: 47 additions & 21 deletions spikewidgets/widgets/mapswidget/activitymapwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import spiketoolkit as st
import matplotlib.pylab as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
from ..utils import LabeledRectangle
from spikewidgets.widgets.basewidget import BaseWidget
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis', background='on', label_color='r',
def plot_activity_map(recording, channel_ids=None, trange=None, activity='rate', log=False,
cmap='viridis', background='on', label_color='r',
transpose=False, frame=False, colorbar=False, colorbar_bbox=None,
colorbar_orientation='vertical', colorbar_width=0.02, recompute_info=False,
ax=None, figure=None):
Expand All @@ -20,17 +20,21 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis',
recording: RecordingExtractor
The recordng extractor object
channel_ids: list
The channel ids to display.
The channel ids to display
trange: list
List with start time and end time
activity: str
'rate' or 'amplitude'. If 'rate' the channel spike rate is used. If 'amplitude' the spike amplitude is used
log: bool
If True, log scale is used
cmap: matplotlib colormap
The colormap to be used (default 'viridis')
background: bool
If True, a background is added in between electrodes
transpose: bool, optional, default: False
Swap x and y channel coordinates if True.
Swap x and y channel coordinates if True
frame: bool, optional, default: False
Draw a frame around the array if True.
Draw a frame around the array if True
colorbar: bool
If True, a colorbar is displayed
colorbar_bbox: bbox
Expand All @@ -55,6 +59,8 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis',
recording=recording,
channel_ids=channel_ids,
trange=trange,
activity=activity,
log=log,
background=background,
cmap=cmap,
label_color=label_color,
Expand All @@ -73,13 +79,15 @@ def plot_activity_map(recording, channel_ids=None, trange=None, cmap='viridis',


class ActivityMapWidget(BaseWidget):
def __init__(self, recording, channel_ids, trange, cmap, background, label_color='r', transpose=False, frame=False,
colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical', colorbar_width=0.02,
recompute_info=False, figure=None, ax=None):
def __init__(self, recording, channel_ids, activity, log, trange, cmap, background, label_color='r',
transpose=False, frame=False, colorbar=False, colorbar_bbox=None, colorbar_orientation='vertical',
colorbar_width=0.02, recompute_info=False, figure=None, ax=None):
BaseWidget.__init__(self, figure, ax)
self._recording = recording
self._channel_ids = channel_ids
self._activity = activity
self._trange = trange
self._log = log
self._transpose = transpose
self._cmap = cmap
self._frame = frame
Expand All @@ -92,6 +100,7 @@ def __init__(self, recording, channel_ids, trange, cmap, background, label_color
self._recompute_info = recompute_info
self.colorbar = None
self.name = 'ActivityMap'
assert activity in ['rate', 'amplitude'], "'activity' can be either 'rate' or 'amplitude'"
assert 'location' in self._recording.get_shared_channel_property_names(), "Activity map requires 'location'" \
"property"

Expand All @@ -106,13 +115,14 @@ def _do_plot(self):
self._trange = [int(t * self._recording.get_sampling_frequency()) for t in self._trange]

locations = self._recording.get_channel_locations(channel_ids=self._channel_ids)
activity = st.postprocessing.compute_channel_spiking_activity(self._recording,
start_frame=self._trange[0],
end_frame=self._trange[1],
method='detection',
align=False,
recompute_info=self._recompute_info,
verbose=False)
spike_rates, spike_amplitudes = st.postprocessing.compute_channel_spiking_activity(self._recording,
start_frame=self._trange[0],
end_frame=self._trange[1],
method='detection',
align=False,
recompute_info=
self._recompute_info,
verbose=False)
if self._transpose:
locations = np.roll(locations, 1, axis=1)

Expand Down Expand Up @@ -143,10 +153,21 @@ def _do_plot(self):
elec_x = 0.9 * pitch_x
elec_y = 0.9 * pitch_y

if self._activity == 'rate':
activity = spike_rates
else: # amplitude
activity = np.abs(spike_amplitudes)

max_activity = np.round(np.max(activity), 2)
min_activity = np.round(np.min(activity), 2)

if self._log:
if np.any(activity < 1):
activity += (1 - np.min(activity))
activity = np.log(activity)

# normalize
activity -= np.min(activity)
activity -= (np.min(activity) + 1e-5)
activity /= np.ptp(activity)

for (loc, act, ch) in zip(locations, activity, self._recording.get_channel_ids()):
Expand Down Expand Up @@ -198,17 +219,22 @@ def _do_plot(self):

cax = inset_axes(self.ax, width="100%", height="100%", bbox_to_anchor=bbox,
bbox_transform=self.ax.transData)
scalable = mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=0, vmax=1), cmap=self._cmap)
if self._log:
norm = mpl.colors.LogNorm(vmin=1e-5, vmax=1)
else:
norm = mpl.colors.Normalize(vmin=0, vmax=1)
scalable = mpl.cm.ScalarMappable(norm=norm, cmap=self._cmap)
self.colorbar = self.figure.colorbar(scalable, cax=cax,
orientation=self._colorbar_orient)#, shrink=0.5)
cax.yaxis.set_ticks_position('left')
cax.yaxis.set_label_position('left')
self.colorbar.set_ticks((0, 1))
self.colorbar.set_ticklabels((0, max_activity))
self.colorbar.set_ticklabels((min_activity, max_activity))
if self._colorbar_orient == 'vertical':
rotation = 90
else:
rotation = 0
self.colorbar.set_label('Sp/s', rotation=rotation)


if self._activity == 'rate':
self.colorbar.set_label('Sp/s', rotation=rotation)
else:
self.colorbar.set_label('Amp.', rotation=rotation)
3 changes: 3 additions & 0 deletions spikewidgets/widgets/mapswidget/templatemapswidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ def _do_plot(self):
for i, (template, unit) in enumerate(zip(templates, unit_ids)):
ax = self.get_tiled_ax(i, nrows, ncols)
temp_map = np.abs(fun(template, axis=1))

if self._log:
if np.any(temp_map < 1):
temp_map += (1 - np.min(temp_map))
temp_map = np.log(temp_map)

# normalize
Expand Down

0 comments on commit 0c27a8c

Please sign in to comment.