From 897a60001ae9565948cdd28eaf2d5b0cc0b7eb44 Mon Sep 17 00:00:00 2001 From: mjpelah Date: Tue, 12 Jul 2022 15:27:20 -0400 Subject: [PATCH 1/7] Upgrade plot_morphology --- hnn_core/network.py | 20 ++++++++++++++++- hnn_core/viz.py | 54 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 717142f51..bdf91f19c 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -17,7 +17,7 @@ from .drives import _check_drive_parameter_values, _check_poisson_rates from .cells_default import pyramidal, basket from .params import _long_name, _short_name -from .viz import plot_cells +from .viz import plot_cell_morphologies, plot_cells from .externals.mne import _validate_type, _check_option from .extracellular import ExtracellularArray from .check import _check_gids, _gid_to_type, _string_input_to_list @@ -1327,6 +1327,24 @@ def plot_cells(self, ax=None, show=True): """ return plot_cells(net=self, ax=ax, show=show) + def plot_cell_morphologies(self, ax=None, show=True): + """Plot the morphology of the network cells + + Parameters + ---------- + ax : instance of matplotlib Axes3D | None + An axis object from matplotlib. If None, + a new figure is created. + show : bool + If True, show the figure. + + Returns + ------- + fig : instance of matplotlib Figure + The matplotlib figure handle. + """ + return plot_cell_morphologies(net=self, ax=ax,) + class _Connectivity(dict): """A class for containing the connectivity details of the network diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 3b3b9e08d..e3557e7ee 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1165,3 +1165,57 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, plt_show(show) return ax.get_figure() + + +def _plot_cell(ax, cell_type=None, show=True): + """Plot the cell morphology of a specific cell type + + parameters + ---------- + cell_type : instance of net.cell_type[] + The type of cell to be plotted. If None, + generic cell type + ax : instance of Axes3D + Matplotlib 3D axis + show : bool + if True, show the plot + + """ + + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + + if ax is none: + plt.figure() + ax = plt.axes(projection='3d') + + return ax + + +def plot_cell_morphologies(net, ax=None, show=true): + """Plot the morphology of the network cells + + Parameters + ---------- + net : instance of Network + The network object + ax : instance of matplotlib Axes3D | None + An axis object from matplotlib. If none, + a new figure is created. + Show : bool + If True, show the figure + + Returns + ------- + fig : instance of matplotlib figure + The matplotlib figure handle + """ + + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import + + if ax is None: + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + + return ax.get_figure() From dc88758e51992a1f6a8d3067db70bf767be5aa03 Mon Sep 17 00:00:00 2001 From: mjpelah Date: Fri, 15 Jul 2022 15:44:08 -0400 Subject: [PATCH 2/7] Added plotting of network cells --- hnn_core/network.py | 5 +- hnn_core/viz.py | 110 +++++++++++++++++++++++++------------------- 2 files changed, 65 insertions(+), 50 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index bdf91f19c..a8158fc0a 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -17,7 +17,7 @@ from .drives import _check_drive_parameter_values, _check_poisson_rates from .cells_default import pyramidal, basket from .params import _long_name, _short_name -from .viz import plot_cell_morphologies, plot_cells +from .viz import plot_cell_morphologies, plot_cells, plot_cell_morphology from .externals.mne import _validate_type, _check_option from .extracellular import ExtracellularArray from .check import _check_gids, _gid_to_type, _string_input_to_list @@ -1343,7 +1343,8 @@ def plot_cell_morphologies(self, ax=None, show=True): fig : instance of matplotlib Figure The matplotlib figure handle. """ - return plot_cell_morphologies(net=self, ax=ax,) + + return plot_cell_morphologies(self, ax=ax, show=show) class _Connectivity(dict): diff --git a/hnn_core/viz.py b/hnn_core/viz.py index e3557e7ee..14ab437c9 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -809,7 +809,6 @@ def _linewidth_from_data_units(ax, linewidth): def plot_cell_morphology(cell, ax, show=True): """Plot the cell morphology. - Parameters ---------- cell : instance of Cell @@ -818,7 +817,6 @@ def plot_cell_morphology(cell, ax, show=True): Matplotlib 3D axis show : bool If True, show the plot - Returns ------- axes : list of instance of Axes3D @@ -835,6 +833,67 @@ def plot_cell_morphology(cell, ax, show=True): ax.set_xlim((cell.pos[1] - 250, cell.pos[1] + 150)) ax.set_zlim((cell.pos[2] - 100, cell.pos[2] + 1200)) + ax = _plot_cell(cell, ax=ax, plt=plt, show=True) + + plt.tight_layout() + plt_show(show) + return ax + +def plot_cell_morphologies(net, ax=None, show=True): + """Plot the morphology of the network cells + + Parameters + ---------- + net : instance of Network + The network object + ax : instance of matplotlib Axes3D | None + An axis object from matplotlib. If none, + a new figure is created. + Show : bool + If True, show the figure + + Returns + ------- + fig : instance of matplotlib figure + The matplotlib figure handle + """ + + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import + + if ax is None: + plt.figure() + ax = plt.axes(projection='3d') + + colors = ['b', 'c', 'r', 'm'] + i=0 + + ax.set_xlim((list(net.cell_types.values())[0].pos[1] - 250, list(net.cell_types.values())[0].pos[1] + 150)) + ax.set_zlim((list(net.cell_types.values())[0].pos[2] - 100, list(net.cell_types.values())[0].pos[2] + 1200)) + + for cell in net.cell_types.values(): + ax = _plot_cell(cell, ax=ax, plt=plt, color=colors[i], show=True) + i+=1 + + return ax + +def _plot_cell(cell, ax, plt, color='b', show=True): + """Plot the cell morphology of a specific cell type + + parameters + ---------- + cell_type : instance of net.cell_type[] + The type of cell to be plotted. If None, + generic cell type + ax : instance of Axes3D + Matplotlib 3D axis + show : bool + if True, show the plot + + """ + + from mpl_toolkits.mplot3d import Axes3D # noqa + for sec_name, section in cell.sections.items(): linewidth = _linewidth_from_data_units(ax, section.diam) end_pts = section.end_pts @@ -846,7 +905,7 @@ def plot_cell_morphology(cell, ax, show=True): xs.append(pt[0] + dx) ys.append(pt[1] + dz) zs.append(pt[2] + dy) - ax.plot(xs, ys, zs, 'b-', linewidth=linewidth) + ax.plot(xs, ys, zs, 'b-', color = color, linewidth=linewidth) ax.view_init(0, -90) ax.axis('off') @@ -1122,51 +1181,6 @@ def _onclick(event): return ax.get_figure() -def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, - show=True): - """Plot laminar current source density (CSD) estimation from LFP array. - - Parameters - ---------- - times : Numpy array, shape (n_times,) - Sampling times (in ms). - data : array-like, shape (n_channels, n_times) - CSD data, channels x time. - ax : instance of matplotlib figure | None - The matplotlib axis. - colorbar : bool - If the colorbar is presented. - contact_labels : list - Labels associated with the contacts to plot. Passed as-is to - :func:`~matplotlib.axes.Axes.set_yticklabels`. - show : bool - If True, show the plot. - - Returns - ------- - fig : instance of matplotlib Figure - The matplotlib figure handle. - """ - import matplotlib.pyplot as plt - if ax is None: - _, ax = plt.subplots(1, 1, constrained_layout=True) - - im = ax.pcolormesh(times, contact_labels, np.array(data), - cmap="jet_r", shading='auto') - ax.set_title("CSD") - - if colorbar: - color_axis = ax.inset_axes([1.05, 0, 0.02, 1], transform=ax.transAxes) - plt.colorbar(im, ax=ax, cax=color_axis).set_label(r'$CSD (uV/um^{2})$') - - ax.set_xlabel('Time (ms)') - ax.set_ylabel('Electrode depth') - plt.tight_layout() - plt_show(show) - - return ax.get_figure() - - def _plot_cell(ax, cell_type=None, show=True): """Plot the cell morphology of a specific cell type From 74b1cfa7adc1e2d965d69d84814da38bbf5a6820 Mon Sep 17 00:00:00 2001 From: mjpelah Date: Wed, 20 Jul 2022 13:34:57 -0400 Subject: [PATCH 3/7] Offset and reorganization --- hnn_core/viz.py | 136 ++++++++++++++++++++---------------------------- 1 file changed, 57 insertions(+), 79 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 14ab437c9..ccb54aa0d 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -809,6 +809,7 @@ def _linewidth_from_data_units(ax, linewidth): def plot_cell_morphology(cell, ax, show=True): """Plot the cell morphology. + Parameters ---------- cell : instance of Cell @@ -817,6 +818,7 @@ def plot_cell_morphology(cell, ax, show=True): Matplotlib 3D axis show : bool If True, show the plot + Returns ------- axes : list of instance of Axes3D @@ -824,91 +826,67 @@ def plot_cell_morphology(cell, ax, show=True): """ import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # noqa + cell_list = list() + clr_index=0 + colors = ['b', 'c', 'r', 'm'] + multiple = 0 if ax is None: plt.figure() ax = plt.axes(projection='3d') - # Cell is in XZ plane - ax.set_xlim((cell.pos[1] - 250, cell.pos[1] + 150)) - ax.set_zlim((cell.pos[2] - 100, cell.pos[2] + 1200)) - - ax = _plot_cell(cell, ax=ax, plt=plt, show=True) - - plt.tight_layout() - plt_show(show) - return ax - -def plot_cell_morphologies(net, ax=None, show=True): - """Plot the morphology of the network cells - - Parameters - ---------- - net : instance of Network - The network object - ax : instance of matplotlib Axes3D | None - An axis object from matplotlib. If none, - a new figure is created. - Show : bool - If True, show the figure - - Returns - ------- - fig : instance of matplotlib figure - The matplotlib figure handle - """ - - import matplotlib.pyplot as plt - from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import - - if ax is None: - plt.figure() - ax = plt.axes(projection='3d') - - colors = ['b', 'c', 'r', 'm'] - i=0 - - ax.set_xlim((list(net.cell_types.values())[0].pos[1] - 250, list(net.cell_types.values())[0].pos[1] + 150)) - ax.set_zlim((list(net.cell_types.values())[0].pos[2] - 100, list(net.cell_types.values())[0].pos[2] + 1200)) - - for cell in net.cell_types.values(): - ax = _plot_cell(cell, ax=ax, plt=plt, color=colors[i], show=True) - i+=1 - - return ax - -def _plot_cell(cell, ax, plt, color='b', show=True): - """Plot the cell morphology of a specific cell type - - parameters - ---------- - cell_type : instance of net.cell_type[] - The type of cell to be plotted. If None, - generic cell type - ax : instance of Axes3D - Matplotlib 3D axis - show : bool - if True, show the plot - - """ - - from mpl_toolkits.mplot3d import Axes3D # noqa - - for sec_name, section in cell.sections.items(): - linewidth = _linewidth_from_data_units(ax, section.diam) - end_pts = section.end_pts - xs, ys, zs = list(), list(), list() - for pt in end_pts: - dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] - dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] - dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] - xs.append(pt[0] + dx) - ys.append(pt[1] + dz) - zs.append(pt[2] + dy) - ax.plot(xs, ys, zs, 'b-', color = color, linewidth=linewidth) - ax.view_init(0, -90) - ax.axis('off') + if type(cell) is dict: + for ind_cell in cell: + cell_list = list(cell.values()) + print("is dict") + else: + print("is not dict") + cell_list[0]=cell + # Cell is in XZ plane + #ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150)) + #ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200)) + cell_radii = [0] + total_radius=0 + for clr_index, cell in enumerate(cell_list): + radius = 0 + + # Calculating the radius for cell offset + for sec_name, section in cell.sections.items(): + end_pts = section.end_pts + xs, ys, zs = list(), list(), list() + for pt in end_pts: + dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] + dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] + dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] + if radius < pt[0]: + radius = pt[0] + total_radius+=radius + + # Plotting the cell + for sec_name, section in cell.sections.items(): + ax.set_xlim((total_radius+200)) + ax.set_zlim((cell.pos[2] - 100, cell.pos[2] + 1200)) + linewidth = _linewidth_from_data_units(ax, section.diam) + end_pts = section.end_pts + xs, ys, zs = list(), list(), list() + for pt in end_pts: + dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] + dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] + dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] + xs.append(pt[0] + dx + ((radius + cell_radii[-1])*multiple)+100) + ys.append(pt[1] + dz) + zs.append(pt[2] + dy) + ax.plot(xs, ys, zs, color=colors[clr_index], linewidth=linewidth) + cell_radii.append(radius) + ax.view_init(0, -90) + ax.axis('on') + ax.grid('off') + ax.set_yticks([]) + ax.set_xticks([]) + print("iteration: ") + print(multiple) + multiple = multiple + 1 plt.tight_layout() plt_show(show) return ax From 9c799dd2643f1b21646b065b8f2cc84344eda824 Mon Sep 17 00:00:00 2001 From: mjpelah Date: Wed, 20 Jul 2022 14:01:21 -0400 Subject: [PATCH 4/7] cleaned, general update --- .../plot_record_extracellular_potentials.py | 4 ++++ hnn_core/network.py | 21 +++-------------- hnn_core/viz.py | 23 ++++++------------- 3 files changed, 14 insertions(+), 34 deletions(-) diff --git a/examples/howto/plot_record_extracellular_potentials.py b/examples/howto/plot_record_extracellular_potentials.py index 7c4a5212c..017eeda70 100644 --- a/examples/howto/plot_record_extracellular_potentials.py +++ b/examples/howto/plot_record_extracellular_potentials.py @@ -44,6 +44,10 @@ # each electrode. net.plot_cells() +############################################################################### +# Plotting the cell morphologies of the network cells +net.plot_cell_morphologies() + ############################################################################### # The default network consists of 2 layers (L2 and L5), within which the cell # somas are arranged in a regular grid, and apical dendrites are aligned along diff --git a/hnn_core/network.py b/hnn_core/network.py index a8158fc0a..716f63a6d 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -17,7 +17,7 @@ from .drives import _check_drive_parameter_values, _check_poisson_rates from .cells_default import pyramidal, basket from .params import _long_name, _short_name -from .viz import plot_cell_morphologies, plot_cells, plot_cell_morphology +from .viz import plot_cells, plot_cell_morphology from .externals.mne import _validate_type, _check_option from .extracellular import ExtracellularArray from .check import _check_gids, _gid_to_type, _string_input_to_list @@ -1327,24 +1327,9 @@ def plot_cells(self, ax=None, show=True): """ return plot_cells(net=self, ax=ax, show=show) - def plot_cell_morphologies(self, ax=None, show=True): - """Plot the morphology of the network cells - - Parameters - ---------- - ax : instance of matplotlib Axes3D | None - An axis object from matplotlib. If None, - a new figure is created. - show : bool - If True, show the figure. - - Returns - ------- - fig : instance of matplotlib Figure - The matplotlib figure handle. - """ - return plot_cell_morphologies(self, ax=ax, show=show) + def plot_cell_morphologies(self, ax=None, show=True): + return plot_cell_morphology(self.cell_types, ax=ax, show=show) class _Connectivity(dict): diff --git a/hnn_core/viz.py b/hnn_core/viz.py index ccb54aa0d..da8c3eca7 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -827,9 +827,8 @@ def plot_cell_morphology(cell, ax, show=True): import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # noqa cell_list = list() - clr_index=0 colors = ['b', 'c', 'r', 'm'] - multiple = 0 + clr_index=0 if ax is None: plt.figure() @@ -838,20 +837,18 @@ def plot_cell_morphology(cell, ax, show=True): if type(cell) is dict: for ind_cell in cell: cell_list = list(cell.values()) - print("is dict") else: - print("is not dict") cell_list[0]=cell # Cell is in XZ plane #ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150)) #ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200)) - cell_radii = [0] - total_radius=0 + cell_radii = list() + cell_radii.append(clr_index) for clr_index, cell in enumerate(cell_list): - radius = 0 # Calculating the radius for cell offset + radius = 0 for sec_name, section in cell.sections.items(): end_pts = section.end_pts xs, ys, zs = list(), list(), list() @@ -861,11 +858,11 @@ def plot_cell_morphology(cell, ax, show=True): dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] if radius < pt[0]: radius = pt[0] - total_radius+=radius + cell_radii.append(radius) # Plotting the cell for sec_name, section in cell.sections.items(): - ax.set_xlim((total_radius+200)) + ax.set_xlim((sum(cell_radii, 100))) ax.set_zlim((cell.pos[2] - 100, cell.pos[2] + 1200)) linewidth = _linewidth_from_data_units(ax, section.diam) end_pts = section.end_pts @@ -874,19 +871,13 @@ def plot_cell_morphology(cell, ax, show=True): dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] - xs.append(pt[0] + dx + ((radius + cell_radii[-1])*multiple)+100) + xs.append(pt[0] + dx + (radius + cell_radii[-1]+100)) ys.append(pt[1] + dz) zs.append(pt[2] + dy) ax.plot(xs, ys, zs, color=colors[clr_index], linewidth=linewidth) - cell_radii.append(radius) ax.view_init(0, -90) ax.axis('on') ax.grid('off') - ax.set_yticks([]) - ax.set_xticks([]) - print("iteration: ") - print(multiple) - multiple = multiple + 1 plt.tight_layout() plt_show(show) return ax From c64ced3ece34d1f80f4307280070d16b9b4d5cf6 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sun, 15 Jan 2023 15:40:03 -0500 Subject: [PATCH 5/7] Rebase fix Co-authored-by: mjpelah --- hnn_core/viz.py | 118 +++++++++++++++--------------------------------- 1 file changed, 37 insertions(+), 81 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index da8c3eca7..caa86d70a 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -60,9 +60,7 @@ def _decimate_plot_data(decim, data, times, sfreq=None): def plt_show(show=True, fig=None, **kwargs): """Show a figure while suppressing warnings. - NB copied from :func:`mne.viz.utils.plt_show`. - Parameters ---------- show : bool @@ -82,7 +80,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, ax=None, decim=None, color='cividis', voltage_offset=50, voltage_scalebar=200, show=True): """Plot laminar extracellular electrode array voltage time series. - Parameters ---------- times : array-like, shape (n_times,) @@ -117,7 +114,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, :func:`~matplotlib.axes.Axes.set_yticklabels`. show : bool If True, show the figure - Returns ------- fig : instance of plt.fig @@ -223,7 +219,6 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, color='k', label="average", average=False, show=True): """Simple layer-specific plot function. - Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -250,7 +245,6 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, If True, render the average across all dpls. show : bool If True, show the figure - Returns ------- fig : instance of plt.fig @@ -322,7 +316,6 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, show=True): """Plot the histogram of spiking activity across trials. - Parameters ---------- cell_response : instance of CellResponse @@ -334,25 +327,17 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, a new figure is created. spike_types: string | list | dictionary | None String input of a valid spike type is plotted individually. - | Ex: ``'poisson'``, ``'evdist'``, ``'evprox'``, ... - List of valid string inputs will plot each spike type individually. - | Ex: ``['poisson', 'evdist']`` - Dictionary of valid lists will plot list elements as a group. - | Ex: ``{'Evoked': ['evdist', 'evprox'], 'Tonic': ['poisson']}`` - If None, all input spike types are plotted individually if any are present. Otherwise spikes from all cells are plotted. Valid strings also include leading characters of spike types - | Ex: ``'ev'`` is equivalent to ``['evdist', 'evprox']`` show : bool If True, show the figure. - Returns ------- fig : instance of matplotlib Figure @@ -444,7 +429,6 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): """Plot the aggregate spiking activity according to cell type. - Parameters ---------- cell_response : instance of CellResponse @@ -455,7 +439,6 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): An axis object from matplotlib. If None, a new figure is created. show : bool If True, show the figure. - Returns ------- fig : instance of matplotlib Figure @@ -520,7 +503,6 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): def plot_cells(net, ax=None, show=True): """Plot the cells using Network.pos_dict. - Parameters ---------- net : instance of Network @@ -530,7 +512,6 @@ def plot_cells(net, ax=None, show=True): a new figure is created. show : bool If True, show the figure. - Returns ------- fig : instance of matplotlib Figure @@ -577,7 +558,6 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, colormap='inferno', colorbar=True, colorbar_inside=False, show=True): """Plot Morlet time-frequency representation of dipole time course - Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -612,7 +592,6 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, Put the color inside the heatmap if True. show : bool If True, show the figure - Returns ------- fig : instance of matplotlib Figure @@ -716,13 +695,11 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', color=None, label=None, ax=None, show=True): """Plot power spectral density (PSD) of dipole time course - Applies `~scipy.signal.periodogram` from SciPy with ``window='hamming'``. Note that no spectral averaging is applied across time, as most ``hnn_core`` simulations are short-duration. However, passing a list of `Dipole` instances will plot their average (Hamming-windowed) power, which resembles the `Welch`-method applied over time. - Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -745,7 +722,6 @@ def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', The matplotlib axis. show : bool If True, show the figure - Returns ------- fig : instance of matplotlib Figure @@ -809,7 +785,6 @@ def _linewidth_from_data_units(ax, linewidth): def plot_cell_morphology(cell, ax, show=True): """Plot the cell morphology. - Parameters ---------- cell : instance of Cell @@ -818,7 +793,6 @@ def plot_cell_morphology(cell, ax, show=True): Matplotlib 3D axis show : bool If True, show the plot - Returns ------- axes : list of instance of Axes3D @@ -828,7 +802,7 @@ def plot_cell_morphology(cell, ax, show=True): from mpl_toolkits.mplot3d import Axes3D # noqa cell_list = list() colors = ['b', 'c', 'r', 'm'] - clr_index=0 + clr_index = 0 if ax is None: plt.figure() @@ -838,11 +812,11 @@ def plot_cell_morphology(cell, ax, show=True): for ind_cell in cell: cell_list = list(cell.values()) else: - cell_list[0]=cell + cell_list[0] = cell # Cell is in XZ plane - #ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150)) - #ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200)) + # ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150)) + # ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200)) cell_radii = list() cell_radii.append(clr_index) for clr_index, cell in enumerate(cell_list): @@ -871,7 +845,7 @@ def plot_cell_morphology(cell, ax, show=True): dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] - xs.append(pt[0] + dx + (radius + cell_radii[-1]+100)) + xs.append(pt[0] + dx + (radius + cell_radii[-1] + 100)) ys.append(pt[1] + dz) zs.append(pt[2] + dy) ax.plot(xs, ys, zs, color=colors[clr_index], linewidth=linewidth) @@ -887,7 +861,6 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, colorbar=True, colormap='Greys', show=True): """Plot connectivity matrix with color bar for synaptic weights - Parameters ---------- net : Instance of Network object @@ -906,7 +879,6 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, If True (default), adjust figure to include colorbar. show : bool If True, show the plot - Returns ------- fig : instance of matplotlib Figure @@ -1017,12 +989,10 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, colorbar=True, colormap='viridis', show=True): """Plot synaptic weight of connections. - This is an interactive plot with source cells shown in the left subplot and connectivity from a source cell to all the target cells in the right subplot. Click on the cells in the left subplot to explore how the connectivity pattern changes for different source cells. - Parameters ---------- net : Instance of Network object @@ -1041,12 +1011,10 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, If True (default), adjust figure to include colorbar. show : bool If True, show the plot - Returns ------- fig : instance of matplotlib Figure The matplotlib figure handle. - Notes ----- Target cells will be determined by the connections in @@ -1055,7 +1023,6 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, it will appear as a smaller black cross. Source cell is plotted as a red square. Source cell will not be plotted if the connection corresponds to a drive, ex: poisson, bursty, etc. - """ import matplotlib.pyplot as plt from .network import Network @@ -1150,55 +1117,44 @@ def _onclick(event): return ax.get_figure() -def _plot_cell(ax, cell_type=None, show=True): - """Plot the cell morphology of a specific cell type - - parameters - ---------- - cell_type : instance of net.cell_type[] - The type of cell to be plotted. If None, - generic cell type - ax : instance of Axes3D - Matplotlib 3D axis - show : bool - if True, show the plot - - """ - - import matplotlib.pyplot as plt - from mpl_toolkits.mplot3d import Axes3D - - if ax is none: - plt.figure() - ax = plt.axes(projection='3d') - - return ax - - -def plot_cell_morphologies(net, ax=None, show=true): - """Plot the morphology of the network cells - +def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, + show=True): + """Plot laminar current source density (CSD) estimation from LFP array. Parameters ---------- - net : instance of Network - The network object - ax : instance of matplotlib Axes3D | None - An axis object from matplotlib. If none, - a new figure is created. - Show : bool - If True, show the figure - + times : Numpy array, shape (n_times,) + Sampling times (in ms). + data : array-like, shape (n_channels, n_times) + CSD data, channels x time. + ax : instance of matplotlib figure | None + The matplotlib axis. + colorbar : bool + If the colorbar is presented. + contact_labels : list + Labels associated with the contacts to plot. Passed as-is to + :func:`~matplotlib.axes.Axes.set_yticklabels`. + show : bool + If True, show the plot. Returns ------- - fig : instance of matplotlib figure - The matplotlib figure handle + fig : instance of matplotlib Figure + The matplotlib figure handle. """ - import matplotlib.pyplot as plt - from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import - if ax is None: - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') + _, ax = plt.subplots(1, 1, constrained_layout=True) + + im = ax.pcolormesh(times, contact_labels, np.array(data), + cmap="jet_r", shading='auto') + ax.set_title("CSD") + + if colorbar: + color_axis = ax.inset_axes([1.05, 0, 0.02, 1], transform=ax.transAxes) + plt.colorbar(im, ax=ax, cax=color_axis).set_label(r'$CSD (uV/um^{2})$') + + ax.set_xlabel('Time (ms)') + ax.set_ylabel('Electrode depth') + plt.tight_layout() + plt_show(show) return ax.get_figure() From 227687ca3ba9df611a3da4a2518935530976886d Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sun, 15 Jan 2023 15:52:33 -0500 Subject: [PATCH 6/7] Fix rebase doc errors --- hnn_core/viz.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index caa86d70a..c644bb46f 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -60,7 +60,9 @@ def _decimate_plot_data(decim, data, times, sfreq=None): def plt_show(show=True, fig=None, **kwargs): """Show a figure while suppressing warnings. + NB copied from :func:`mne.viz.utils.plt_show`. + Parameters ---------- show : bool @@ -80,6 +82,7 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, ax=None, decim=None, color='cividis', voltage_offset=50, voltage_scalebar=200, show=True): """Plot laminar extracellular electrode array voltage time series. + Parameters ---------- times : array-like, shape (n_times,) @@ -114,6 +117,7 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, :func:`~matplotlib.axes.Axes.set_yticklabels`. show : bool If True, show the figure + Returns ------- fig : instance of plt.fig @@ -219,6 +223,7 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, color='k', label="average", average=False, show=True): """Simple layer-specific plot function. + Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -245,6 +250,7 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, If True, render the average across all dpls. show : bool If True, show the figure + Returns ------- fig : instance of plt.fig @@ -316,6 +322,7 @@ def plot_dipole(dpl, tmin=None, tmax=None, ax=None, layer='agg', decim=None, def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, show=True): """Plot the histogram of spiking activity across trials. + Parameters ---------- cell_response : instance of CellResponse @@ -327,17 +334,24 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, a new figure is created. spike_types: string | list | dictionary | None String input of a valid spike type is plotted individually. + | Ex: ``'poisson'``, ``'evdist'``, ``'evprox'``, ... + List of valid string inputs will plot each spike type individually. + | Ex: ``['poisson', 'evdist']`` Dictionary of valid lists will plot list elements as a group. + | Ex: ``{'Evoked': ['evdist', 'evprox'], 'Tonic': ['poisson']}`` + If None, all input spike types are plotted individually if any are present. Otherwise spikes from all cells are plotted. Valid strings also include leading characters of spike types + | Ex: ``'ev'`` is equivalent to ``['evdist', 'evprox']`` show : bool If True, show the figure. + Returns ------- fig : instance of matplotlib Figure @@ -429,6 +443,7 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): """Plot the aggregate spiking activity according to cell type. + Parameters ---------- cell_response : instance of CellResponse @@ -439,6 +454,7 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): An axis object from matplotlib. If None, a new figure is created. show : bool If True, show the figure. + Returns ------- fig : instance of matplotlib Figure @@ -503,6 +519,7 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True): def plot_cells(net, ax=None, show=True): """Plot the cells using Network.pos_dict. + Parameters ---------- net : instance of Network @@ -512,6 +529,7 @@ def plot_cells(net, ax=None, show=True): a new figure is created. show : bool If True, show the figure. + Returns ------- fig : instance of matplotlib Figure @@ -558,6 +576,7 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, colormap='inferno', colorbar=True, colorbar_inside=False, show=True): """Plot Morlet time-frequency representation of dipole time course + Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -592,6 +611,7 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, Put the color inside the heatmap if True. show : bool If True, show the figure + Returns ------- fig : instance of matplotlib Figure @@ -695,11 +715,13 @@ def plot_tfr_morlet(dpl, freqs, *, n_cycles=7., tmin=None, tmax=None, def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', color=None, label=None, ax=None, show=True): """Plot power spectral density (PSD) of dipole time course + Applies `~scipy.signal.periodogram` from SciPy with ``window='hamming'``. Note that no spectral averaging is applied across time, as most ``hnn_core`` simulations are short-duration. However, passing a list of `Dipole` instances will plot their average (Hamming-windowed) power, which resembles the `Welch`-method applied over time. + Parameters ---------- dpl : instance of Dipole | list of Dipole instances @@ -722,6 +744,7 @@ def plot_psd(dpl, *, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg', The matplotlib axis. show : bool If True, show the figure + Returns ------- fig : instance of matplotlib Figure @@ -785,6 +808,7 @@ def _linewidth_from_data_units(ax, linewidth): def plot_cell_morphology(cell, ax, show=True): """Plot the cell morphology. + Parameters ---------- cell : instance of Cell @@ -793,6 +817,7 @@ def plot_cell_morphology(cell, ax, show=True): Matplotlib 3D axis show : bool If True, show the plot + Returns ------- axes : list of instance of Axes3D @@ -861,6 +886,7 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, colorbar=True, colormap='Greys', show=True): """Plot connectivity matrix with color bar for synaptic weights + Parameters ---------- net : Instance of Network object @@ -879,6 +905,7 @@ def plot_connectivity_matrix(net, conn_idx, ax=None, show_weight=True, If True (default), adjust figure to include colorbar. show : bool If True, show the plot + Returns ------- fig : instance of matplotlib Figure @@ -993,6 +1020,7 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, subplot and connectivity from a source cell to all the target cells in the right subplot. Click on the cells in the left subplot to explore how the connectivity pattern changes for different source cells. + Parameters ---------- net : Instance of Network object @@ -1011,10 +1039,12 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, If True (default), adjust figure to include colorbar. show : bool If True, show the plot + Returns ------- fig : instance of matplotlib Figure The matplotlib figure handle. + Notes ----- Target cells will be determined by the connections in @@ -1120,6 +1150,7 @@ def _onclick(event): def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, show=True): """Plot laminar current source density (CSD) estimation from LFP array. + Parameters ---------- times : Numpy array, shape (n_times,) @@ -1135,6 +1166,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, :func:`~matplotlib.axes.Axes.set_yticklabels`. show : bool If True, show the plot. + Returns ------- fig : instance of matplotlib Figure From 49b577977068a8f5028440573b279fe007e8e57d Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sun, 15 Jan 2023 15:56:26 -0500 Subject: [PATCH 7/7] Flake8 errors --- hnn_core/viz.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index c644bb46f..56a19bb70 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -117,7 +117,7 @@ def plot_laminar_lfp(times, data, contact_labels, tmin=None, tmax=None, :func:`~matplotlib.axes.Axes.set_yticklabels`. show : bool If True, show the figure - + Returns ------- fig : instance of plt.fig @@ -340,6 +340,7 @@ def plot_spikes_hist(cell_response, trial_idx=None, ax=None, spike_types=None, List of valid string inputs will plot each spike type individually. | Ex: ``['poisson', 'evdist']`` + Dictionary of valid lists will plot list elements as a group. | Ex: ``{'Evoked': ['evdist', 'evprox'], 'Tonic': ['poisson']}`` @@ -1016,6 +1017,7 @@ def _update_target_plot(ax, conn, src_gid, src_type_pos, target_type_pos, def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, colorbar=True, colormap='viridis', show=True): """Plot synaptic weight of connections. + This is an interactive plot with source cells shown in the left subplot and connectivity from a source cell to all the target cells in the right subplot. Click on the cells in the left subplot to @@ -1044,7 +1046,7 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, ------- fig : instance of matplotlib Figure The matplotlib figure handle. - + Notes ----- Target cells will be determined by the connections in @@ -1166,7 +1168,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, :func:`~matplotlib.axes.Axes.set_yticklabels`. show : bool If True, show the plot. - + Returns ------- fig : instance of matplotlib Figure