Skip to content

Commit

Permalink
update for plot module
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed Jul 26, 2024
1 parent 3e328ba commit 50632e2
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 73 deletions.
2 changes: 1 addition & 1 deletion stereo/plots/plot_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def save_plot(
:param save_height: the height of the saved plot, defaults to be the same as the plot.
:param save_only_in_view: only save the plot in the view, defaults to False.
:param with_base_image: whether to save the plot with the base image, defaults to False.
Currently, the resolution of the saved base image may not be high.
Currently, the dpi of the saved base image may not be high.
"""
self._set_firefox_and_driver_path()
Expand Down
78 changes: 49 additions & 29 deletions stereo/plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def marker_genes_volcano(
group_name: str,
res_key: Optional[str] = 'marker_genes',
hue_order: Optional[set] = ('down', 'normal', 'up'),
palette: Optional[str] = ("#377EB8", "grey", "#E41A1C"),
palette: Optional[Union[list, tuple]] = ("#377EB8", "grey", "#E41A1C"),
alpha: Optional[int] = 1,
dot_size: Optional[int] = 15,
text_genes: Optional[list] = None,
Expand All @@ -225,7 +225,8 @@ def marker_genes_volcano(
:param group_name: the group name.
:param res_key: the result key of marker gene.
:param hue_order: the classification method.
:param palette: the color theme.
:param palette: the color theme, a list of colors whose length is 3,
in which, each one respectively specifies the color of 'down', 'normal' and 'up' marker genes.
:param alpha: the opacity.
:param dot_size: the dot size.
:param text_genes: show gene names.
Expand Down Expand Up @@ -277,7 +278,7 @@ def genes_count(
:param y_label: list of y label.
:param ncols: the number of columns.
:param dot_size: the dot size.
:param palette: color theme.
:param palette: a single color specifying the color of markers.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
:param out_path: the path to save the figure.
Expand Down Expand Up @@ -334,7 +335,7 @@ def spatial_scatter(
cells_key: Optional[list] = ["total_counts", "n_genes_by_counts"],
ncols: Optional[int] = 2,
dot_size: Optional[int] = None,
palette: Optional[str] = 'stereo',
palette: Optional[Union[str, list]] = 'stereo',
width: Optional[int] = None,
height: Optional[int] = None,
x_label: Optional[Union[list, str]] = 'spatial1',
Expand All @@ -350,7 +351,7 @@ def spatial_scatter(
:param cells_key: specified cells key list.
:param ncols: the number of plot columns.
:param dot_size: the dot size.
:param palette: the color theme.
:param palette: a palette name or a list of colors.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
:param x_label: list of x label.
Expand Down Expand Up @@ -406,7 +407,7 @@ def spatial_scatter_by_gene(
self,
gene_name: Union[str, list, np.ndarray],
dot_size: Optional[int] = None,
palette: Optional[str] = 'CET_L4',
palette: Optional[Union[str, list]] = 'CET_L4',
color_bar_reverse: Optional[bool] = True,
width: Optional[int] = None,
height: Optional[int] = None,
Expand All @@ -421,7 +422,7 @@ def spatial_scatter_by_gene(
:param gene_name: a gene or a list of genes you want to show.
:param dot_size: the dot size, defaults to `None`.
:param palette: the color theme, defaults to `'CET_L4'`.
:param palette: a palette name or a list of colors, defaults to `'CET_L4'`.
:param color_bar_reverse: if True, reverse the color bar, defaults to False
:param width: the figure width in pixels.
:param height: the figure height in pixels.
Expand Down Expand Up @@ -480,7 +481,7 @@ def gaussian_smooth_scatter_by_gene(
self,
gene_name: str = None,
dot_size: Optional[int] = None,
palette: Optional[str] = 'CET_L4',
palette: Optional[Union[str, list]] = 'CET_L4',
color_bar_reverse: Optional[bool] = True,
width: Optional[int] = None,
height: Optional[int] = None,
Expand All @@ -496,7 +497,7 @@ def gaussian_smooth_scatter_by_gene(
:param gene_name: specify the gene you want to draw, if `None` by default, will select randomly.
:param dot_size: marker sizemarker size, defaults to `None`.
:param palette: Color theme, defaults to `'CET_L4'`.
:param palette: a palette name or a list of colors, defaults to `'CET_L4'`.
:param color_bar_reverse: if True, reverse the color bar, defaults to False
:param width: the figure width in pixels.
:param height: the figure height in pixels.
Expand Down Expand Up @@ -668,7 +669,7 @@ def batches_umap(
y_label: Optional[str] = 'umap2',
bfig_title: Optional[str] = 'all batches',
dot_size: Optional[int] = 1,
palette: Optional[Union[str, list]] = 'stereo_30',
palette: Optional[Union[str, list, dict]] = 'stereo_30',
width: Optional[int] = None,
height: Optional[int] = None
):
Expand All @@ -681,7 +682,8 @@ def batches_umap(
:param y_label: the y label.
:param bfig_title: the big figure title.
:param dot_size: the dot size.
:param palette: the color list.
:param palette: a palette name, a list of colors whose length is equal to the batches,
or a dict whose keys are batch numbers and values are colors.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
Expand All @@ -707,7 +709,7 @@ def batches_umap(
umap_res['batch'] = self.data.cells.batch.astype(np.uint16)
batch_number_unique = np.unique(umap_res['batch'])
batch_count = len(batch_number_unique)
cmap = stereo_conf.get_colors(palette, batch_count)
cmap = stereo_conf.get_colors(palette, batch_count, order=batch_number_unique)
fig_all = umap_res.hvplot.scatter(
x='x', y='y', c='batch', cmap=cmap, cnorm='eq_hist',
).opts(
Expand Down Expand Up @@ -771,7 +773,7 @@ def umap(
dot_size: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
palette: Optional[int] = None,
palette: Optional[Union[int, list]] = None,
vmin: float = None,
vmax: float = None,
**kwargs
Expand All @@ -788,7 +790,7 @@ def umap(
:param dot_size: the dot size.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
:param palette: color theme.
:param palette: a palette name of a list of colors.
:param out_path: the path to save the figure.
:param out_dpi: the dpi when the figure is saved.
:param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
Expand Down Expand Up @@ -854,7 +856,8 @@ def cluster_scatter(
x_label: Optional[str] = None,
y_label: Optional[str] = None,
dot_size: Optional[int] = None,
colors: Optional[str] = 'stereo_30',
# colors: Optional[str] = 'stereo_30',
palette: Optional[Union[str, dict, list]] = 'stereo_30',
invert_y: Optional[bool] = True,
hue_order: Optional[set] = None,
width: Optional[int] = None,
Expand All @@ -876,7 +879,8 @@ def cluster_scatter(
:param x_label: the x label.
:param y_label: the y label.
:param dot_size: the dot size.
:param colors: the color list.
:param palette: a palette name, a list of colors whose length at least equal to the groups to be shown or
a dict whose keys are the groups and values are the colors.
:param invert_y: whether to invert y-axis.
:param hue_order: the classification method.
:param width: the figure width in pixels.
Expand All @@ -903,8 +907,10 @@ def cluster_scatter(
""" # noqa
res = self.check_res_key(res_key)
group_list = res['group'].to_numpy(copy=True)
n = np.unique(group_list).size
palette = stereo_conf.get_colors(colors, n=n)
if hue_order is None:
hue_order = natsorted(np.unique(group_list))
n = len(hue_order)
# palette = stereo_conf.get_colors(colors, n=n, order=hue_order)
x = self.data.position[:, 0]
y = self.data.position[:, 1]
x_min, x_max = int(x.min()), int(x.max())
Expand All @@ -926,13 +932,16 @@ def cluster_scatter(
if show_others:
group_list[~isin] = 'others'
n = np.unique(group_list).size
palette = palette[0:n - 1] + [others_color]
# palette = palette[0:n - 1] + [others_color]
hue_order = natsorted(np.unique(group_list[isin])) + ['others']
palette = stereo_conf.get_colors(palette, n=n-1, order=hue_order)
palette.append(others_color)
else:
group_list = group_list[isin]
n = np.unique(group_list).size
palette = palette[0:n]
# palette = palette[0:n]
hue_order = natsorted(np.unique(group_list))
palette = stereo_conf.get_colors(palette, n=n, order=hue_order)
x = x[isin]
y = y[isin]

Expand Down Expand Up @@ -1199,7 +1208,7 @@ def hotspot_modules(
res_key: str = 'spatial_hotspot',
ncols: Optional[int] = 2,
dot_size: Optional[int] = None,
palette: Optional[str] = 'stereo',
palette: Optional[Union[str, list]] = 'stereo',
width: Optional[str] = None,
height: Optional[str] = None,
title: Optional[str] = None,
Expand All @@ -1213,7 +1222,7 @@ def hotspot_modules(
:param res_key: the result key of spatial hotspot.
:param ncols: the number of columns.
:param dot_size: the dot size.
:param palette: Color theme, defaults to `'CET_L4'`.
:param palette: a palette name or a list of colors, defaults to `'stereo'`.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
:param out_path: the path to save the figure.
Expand Down Expand Up @@ -1318,8 +1327,9 @@ def cells_plotting(
:param color_key: the key to get the data to color the plot, it is ignored when the `color_by` is set to 'total_counts' or 'n_genes_by_counts'.
:param bgcolor: set background color.
:param palette: color theme,
when `color_by` is 'cluster', it can be a palette name, a list of colors or a dictionary of colors whose keys is the cluster name,
when other `color_by` is set, it only can be a string of palette name.
when `color_by` is 'cluster', it can be a palette name, a list of colors whose length equal to the groups,
or a dict whose keys are the groups and values are colors,
when other `color_by` is set, it only can be a palette name.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
:param fg_alpha: the transparency of foreground image, between 0 and 1, defaults to 0.5
Expand All @@ -1342,27 +1352,37 @@ def cells_plotting(
:param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates.
:param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates.
:return the figure object if `show` is set to False, otherwise, show the figure directly.
.. note::
Exporting
------------------
This plot can be exported as PNG and SVG, then converted to PDF.
This plot can be exported as PNG, SVG or PDF.
You need the following necessary dependencies to support exporting:
conda install -c conda-forge selenium firefox geckodriver cairosvg
On Linux systems, you may need to install some additional libraries to support the above dependencies,
for example, on Ubuntu, you need to install the following libraries:
On Linux, you may need to install some additional libraries to support the above dependencies,
for example, on Ubuntu, the following libraries need to be installed:
sudo apt-get install libgtk-3-dev libasound2-dev
On other linux systems, you may need to install the corresponding libraries according to the error message.
On others Linux, you may need to install the corresponding libraries according to the error message.
There are two ways to export the plot, one is to manupulate on browser,
There are two ways to export the plot, one is to manupulate on browser when you run it on jupyter notebook,
another is to call the method `save_plot <stereo.plots.plot_cells.PlotCells.save_plot.html>`_ of this figure object.
Example code for the second way:
.. code-block:: python
fig = data.plt.cells_plotting(show=False)
fig.show()
fig.save_plot('path/to/save/plot.pdf')
""" # noqa
from .plot_cells import PlotCells
if color_by in ('cluster', 'gene'):
Expand Down
11 changes: 1 addition & 10 deletions stereo/plots/plot_ms_spatial_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def ms_spatial_scatter(
height: Optional[int] = None,
x_label: Optional[Union[list, str]] = 'spatial1',
y_label: Optional[Union[list, str]] = 'spatial2',
title: Optional[str] = None,
vmin: float = None,
vmax: float = None,
marker: str = 'o',
Expand All @@ -61,7 +60,6 @@ def ms_spatial_scatter(
by default, it will be set to 6 times of `nrows`.
:param x_label: the label of x-axis, defaults to 'spatial1'.
:param y_label: the label of y-axis, defaults to 'spatial2'.
:param title: the title of each slice plot, defaults to None.
:param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
:param vmax: The value representing the higher limit of the color scale. Values greater than vmax are plotted with the same color as vmax.
vmin and vmax will be ignored when `color_by` is 'cluster'.
Expand Down Expand Up @@ -162,13 +160,6 @@ def ms_spatial_scatter(
x = data.position[:, 0]
y = data.position[:, 1]
hue = hue_list[idx]
# if color_by == 'cluster':
# title = color_key
# elif color_by == 'gene':
# title = color_key
# else:
# title = color_by
# ax.set_title(title)
ax.set_title(f'sample {self.ms_data.names[idx]}')
if len(hue) == 0:
continue
Expand All @@ -186,7 +177,7 @@ def ms_spatial_scatter(

return fig

def _get_row_col(self, ncols: int):
def _get_row_col(self, ncols: int = None):
if ncols is None:
ncols = self.__default_ncols
ncols = min(ncols, self.ms_data.num_slice)
Expand Down
17 changes: 12 additions & 5 deletions stereo/plots/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,18 @@ def base_scatter(
from natsort import natsorted
import collections
g = natsorted(set(hue))
if hue_order is not None:
g = hue_order
colors = stereo_conf.get_colors(palette, n=len(g))
color_dict = collections.OrderedDict(dict([(g[i], colors[i]) for i in range(len(g))]))
sns.scatterplot(x=x, y=y, hue=hue, hue_order=g, linewidth=0, marker=marker,
if hue_order is None:
hue_order = g
# if isinstance(palette, (dict, collections.OrderedDict)):
# palette = [palette[i] for i in g if i in palette]
# if len(palette) < len(g):
# colors = stereo_conf.get_colors(palette, n=len(g))
# else:
# colors = palette
# color_dict = collections.OrderedDict(dict([(g[i], colors[i]) for i in range(len(g))]))
colors = stereo_conf.get_colors(palette, n=len(g), order=hue_order)
color_dict = dict(zip(hue_order, colors))
sns.scatterplot(x=x, y=y, hue=hue, hue_order=hue_order, linewidth=0, marker=marker,
palette=color_dict, size=hue, sizes=(dot_size, dot_size), ax=ax, alpha=foreground_alpha)
handles, labels = ax.get_legend_handles_labels()
# ax.legend_.remove()
Expand Down
Loading

0 comments on commit 50632e2

Please sign in to comment.