diff --git a/examples/howto/plot_record_extracellular_potentials.py b/examples/howto/plot_record_extracellular_potentials.py index 7c4a5212c..017eeda70 100644 --- a/examples/howto/plot_record_extracellular_potentials.py +++ b/examples/howto/plot_record_extracellular_potentials.py @@ -44,6 +44,10 @@ # each electrode. net.plot_cells() +############################################################################### +# Plotting the cell morphologies of the network cells +net.plot_cell_morphologies() + ############################################################################### # The default network consists of 2 layers (L2 and L5), within which the cell # somas are arranged in a regular grid, and apical dendrites are aligned along diff --git a/hnn_core/network.py b/hnn_core/network.py index 717142f51..716f63a6d 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -17,7 +17,7 @@ from .drives import _check_drive_parameter_values, _check_poisson_rates from .cells_default import pyramidal, basket from .params import _long_name, _short_name -from .viz import plot_cells +from .viz import plot_cells, plot_cell_morphology from .externals.mne import _validate_type, _check_option from .extracellular import ExtracellularArray from .check import _check_gids, _gid_to_type, _string_input_to_list @@ -1328,6 +1328,10 @@ def plot_cells(self, ax=None, show=True): return plot_cells(net=self, ax=ax, show=show) + def plot_cell_morphologies(self, ax=None, show=True): + return plot_cell_morphology(self.cell_types, ax=ax, show=show) + + class _Connectivity(dict): """A class for containing the connectivity details of the network diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 3b3b9e08d..56a19bb70 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -826,30 +826,58 @@ def plot_cell_morphology(cell, ax, show=True): """ import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # noqa + cell_list = list() + colors = ['b', 'c', 'r', 'm'] + clr_index = 0 if ax is None: plt.figure() ax = plt.axes(projection='3d') - # Cell is in XZ plane - ax.set_xlim((cell.pos[1] - 250, cell.pos[1] + 150)) - ax.set_zlim((cell.pos[2] - 100, cell.pos[2] + 1200)) - - for sec_name, section in cell.sections.items(): - linewidth = _linewidth_from_data_units(ax, section.diam) - end_pts = section.end_pts - xs, ys, zs = list(), list(), list() - for pt in end_pts: - dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] - dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] - dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] - xs.append(pt[0] + dx) - ys.append(pt[1] + dz) - zs.append(pt[2] + dy) - ax.plot(xs, ys, zs, 'b-', linewidth=linewidth) - ax.view_init(0, -90) - ax.axis('off') + if type(cell) is dict: + for ind_cell in cell: + cell_list = list(cell.values()) + else: + cell_list[0] = cell + # Cell is in XZ plane + # ax.set_xlim((cell_list[0].pos[1] - 250, cell_list[0].pos[1] + 150)) + # ax.set_zlim((cell_list[0].pos[2] - 100, cell_list[0].pos[2] + 1200)) + cell_radii = list() + cell_radii.append(clr_index) + for clr_index, cell in enumerate(cell_list): + + # Calculating the radius for cell offset + radius = 0 + for sec_name, section in cell.sections.items(): + end_pts = section.end_pts + xs, ys, zs = list(), list(), list() + for pt in end_pts: + dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] + dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] + dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] + if radius < pt[0]: + radius = pt[0] + cell_radii.append(radius) + + # Plotting the cell + for sec_name, section in cell.sections.items(): + ax.set_xlim((sum(cell_radii, 100))) + ax.set_zlim((cell.pos[2] - 100, cell.pos[2] + 1200)) + linewidth = _linewidth_from_data_units(ax, section.diam) + end_pts = section.end_pts + xs, ys, zs = list(), list(), list() + for pt in end_pts: + dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] + dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] + dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] + xs.append(pt[0] + dx + (radius + cell_radii[-1] + 100)) + ys.append(pt[1] + dz) + zs.append(pt[2] + dy) + ax.plot(xs, ys, zs, color=colors[clr_index], linewidth=linewidth) + ax.view_init(0, -90) + ax.axis('on') + ax.grid('off') plt.tight_layout() plt_show(show) return ax @@ -1027,7 +1055,6 @@ def plot_cell_connectivity(net, conn_idx, src_gid=None, axes=None, it will appear as a smaller black cross. Source cell is plotted as a red square. Source cell will not be plotted if the connection corresponds to a drive, ex: poisson, bursty, etc. - """ import matplotlib.pyplot as plt from .network import Network