From 50632e27a316fd88a265276687c6010dee50898d Mon Sep 17 00:00:00 2001 From: tanliwei Date: Fri, 26 Jul 2024 15:41:17 +0800 Subject: [PATCH] update for plot module --- stereo/plots/plot_cells.py | 2 +- stereo/plots/plot_collection.py | 78 +++++++++++------- stereo/plots/plot_ms_spatial_scatter.py | 11 +-- stereo/plots/scatter.py | 17 ++-- stereo/stereo_config.py | 105 +++++++++++++++++------- 5 files changed, 140 insertions(+), 73 deletions(-) diff --git a/stereo/plots/plot_cells.py b/stereo/plots/plot_cells.py index 96035c2f..fa4c92e3 100644 --- a/stereo/plots/plot_cells.py +++ b/stereo/plots/plot_cells.py @@ -358,7 +358,7 @@ def save_plot( :param save_height: the height of the saved plot, defaults to be the same as the plot. :param save_only_in_view: only save the plot in the view, defaults to False. :param with_base_image: whether to save the plot with the base image, defaults to False. - Currently, the resolution of the saved base image may not be high. + Currently, the dpi of the saved base image may not be high. """ self._set_firefox_and_driver_path() diff --git a/stereo/plots/plot_collection.py b/stereo/plots/plot_collection.py index 794ded9c..2e33522f 100644 --- a/stereo/plots/plot_collection.py +++ b/stereo/plots/plot_collection.py @@ -206,7 +206,7 @@ def marker_genes_volcano( group_name: str, res_key: Optional[str] = 'marker_genes', hue_order: Optional[set] = ('down', 'normal', 'up'), - palette: Optional[str] = ("#377EB8", "grey", "#E41A1C"), + palette: Optional[Union[list, tuple]] = ("#377EB8", "grey", "#E41A1C"), alpha: Optional[int] = 1, dot_size: Optional[int] = 15, text_genes: Optional[list] = None, @@ -225,7 +225,8 @@ def marker_genes_volcano( :param group_name: the group name. :param res_key: the result key of marker gene. :param hue_order: the classification method. - :param palette: the color theme. + :param palette: the color theme, a list of colors whose length is 3, + in which, each one respectively specifies the color of 'down', 'normal' and 'up' marker genes. :param alpha: the opacity. :param dot_size: the dot size. :param text_genes: show gene names. @@ -277,7 +278,7 @@ def genes_count( :param y_label: list of y label. :param ncols: the number of columns. :param dot_size: the dot size. - :param palette: color theme. + :param palette: a single color specifying the color of markers. :param width: the figure width in pixels. :param height: the figure height in pixels. :param out_path: the path to save the figure. @@ -334,7 +335,7 @@ def spatial_scatter( cells_key: Optional[list] = ["total_counts", "n_genes_by_counts"], ncols: Optional[int] = 2, dot_size: Optional[int] = None, - palette: Optional[str] = 'stereo', + palette: Optional[Union[str, list]] = 'stereo', width: Optional[int] = None, height: Optional[int] = None, x_label: Optional[Union[list, str]] = 'spatial1', @@ -350,7 +351,7 @@ def spatial_scatter( :param cells_key: specified cells key list. :param ncols: the number of plot columns. :param dot_size: the dot size. - :param palette: the color theme. + :param palette: a palette name or a list of colors. :param width: the figure width in pixels. :param height: the figure height in pixels. :param x_label: list of x label. @@ -406,7 +407,7 @@ def spatial_scatter_by_gene( self, gene_name: Union[str, list, np.ndarray], dot_size: Optional[int] = None, - palette: Optional[str] = 'CET_L4', + palette: Optional[Union[str, list]] = 'CET_L4', color_bar_reverse: Optional[bool] = True, width: Optional[int] = None, height: Optional[int] = None, @@ -421,7 +422,7 @@ def spatial_scatter_by_gene( :param gene_name: a gene or a list of genes you want to show. :param dot_size: the dot size, defaults to `None`. - :param palette: the color theme, defaults to `'CET_L4'`. + :param palette: a palette name or a list of colors, defaults to `'CET_L4'`. :param color_bar_reverse: if True, reverse the color bar, defaults to False :param width: the figure width in pixels. :param height: the figure height in pixels. @@ -480,7 +481,7 @@ def gaussian_smooth_scatter_by_gene( self, gene_name: str = None, dot_size: Optional[int] = None, - palette: Optional[str] = 'CET_L4', + palette: Optional[Union[str, list]] = 'CET_L4', color_bar_reverse: Optional[bool] = True, width: Optional[int] = None, height: Optional[int] = None, @@ -496,7 +497,7 @@ def gaussian_smooth_scatter_by_gene( :param gene_name: specify the gene you want to draw, if `None` by default, will select randomly. :param dot_size: marker sizemarker size, defaults to `None`. - :param palette: Color theme, defaults to `'CET_L4'`. + :param palette: a palette name or a list of colors, defaults to `'CET_L4'`. :param color_bar_reverse: if True, reverse the color bar, defaults to False :param width: the figure width in pixels. :param height: the figure height in pixels. @@ -668,7 +669,7 @@ def batches_umap( y_label: Optional[str] = 'umap2', bfig_title: Optional[str] = 'all batches', dot_size: Optional[int] = 1, - palette: Optional[Union[str, list]] = 'stereo_30', + palette: Optional[Union[str, list, dict]] = 'stereo_30', width: Optional[int] = None, height: Optional[int] = None ): @@ -681,7 +682,8 @@ def batches_umap( :param y_label: the y label. :param bfig_title: the big figure title. :param dot_size: the dot size. - :param palette: the color list. + :param palette: a palette name, a list of colors whose length is equal to the batches, + or a dict whose keys are batch numbers and values are colors. :param width: the figure width in pixels. :param height: the figure height in pixels. @@ -707,7 +709,7 @@ def batches_umap( umap_res['batch'] = self.data.cells.batch.astype(np.uint16) batch_number_unique = np.unique(umap_res['batch']) batch_count = len(batch_number_unique) - cmap = stereo_conf.get_colors(palette, batch_count) + cmap = stereo_conf.get_colors(palette, batch_count, order=batch_number_unique) fig_all = umap_res.hvplot.scatter( x='x', y='y', c='batch', cmap=cmap, cnorm='eq_hist', ).opts( @@ -771,7 +773,7 @@ def umap( dot_size: Optional[int] = None, width: Optional[int] = None, height: Optional[int] = None, - palette: Optional[int] = None, + palette: Optional[Union[int, list]] = None, vmin: float = None, vmax: float = None, **kwargs @@ -788,7 +790,7 @@ def umap( :param dot_size: the dot size. :param width: the figure width in pixels. :param height: the figure height in pixels. - :param palette: color theme. + :param palette: a palette name of a list of colors. :param out_path: the path to save the figure. :param out_dpi: the dpi when the figure is saved. :param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin. @@ -854,7 +856,8 @@ def cluster_scatter( x_label: Optional[str] = None, y_label: Optional[str] = None, dot_size: Optional[int] = None, - colors: Optional[str] = 'stereo_30', + # colors: Optional[str] = 'stereo_30', + palette: Optional[Union[str, dict, list]] = 'stereo_30', invert_y: Optional[bool] = True, hue_order: Optional[set] = None, width: Optional[int] = None, @@ -876,7 +879,8 @@ def cluster_scatter( :param x_label: the x label. :param y_label: the y label. :param dot_size: the dot size. - :param colors: the color list. + :param palette: a palette name, a list of colors whose length at least equal to the groups to be shown or + a dict whose keys are the groups and values are the colors. :param invert_y: whether to invert y-axis. :param hue_order: the classification method. :param width: the figure width in pixels. @@ -903,8 +907,10 @@ def cluster_scatter( """ # noqa res = self.check_res_key(res_key) group_list = res['group'].to_numpy(copy=True) - n = np.unique(group_list).size - palette = stereo_conf.get_colors(colors, n=n) + if hue_order is None: + hue_order = natsorted(np.unique(group_list)) + n = len(hue_order) + # palette = stereo_conf.get_colors(colors, n=n, order=hue_order) x = self.data.position[:, 0] y = self.data.position[:, 1] x_min, x_max = int(x.min()), int(x.max()) @@ -926,13 +932,16 @@ def cluster_scatter( if show_others: group_list[~isin] = 'others' n = np.unique(group_list).size - palette = palette[0:n - 1] + [others_color] + # palette = palette[0:n - 1] + [others_color] hue_order = natsorted(np.unique(group_list[isin])) + ['others'] + palette = stereo_conf.get_colors(palette, n=n-1, order=hue_order) + palette.append(others_color) else: group_list = group_list[isin] n = np.unique(group_list).size - palette = palette[0:n] + # palette = palette[0:n] hue_order = natsorted(np.unique(group_list)) + palette = stereo_conf.get_colors(palette, n=n, order=hue_order) x = x[isin] y = y[isin] @@ -1199,7 +1208,7 @@ def hotspot_modules( res_key: str = 'spatial_hotspot', ncols: Optional[int] = 2, dot_size: Optional[int] = None, - palette: Optional[str] = 'stereo', + palette: Optional[Union[str, list]] = 'stereo', width: Optional[str] = None, height: Optional[str] = None, title: Optional[str] = None, @@ -1213,7 +1222,7 @@ def hotspot_modules( :param res_key: the result key of spatial hotspot. :param ncols: the number of columns. :param dot_size: the dot size. - :param palette: Color theme, defaults to `'CET_L4'`. + :param palette: a palette name or a list of colors, defaults to `'stereo'`. :param width: the figure width in pixels. :param height: the figure height in pixels. :param out_path: the path to save the figure. @@ -1318,8 +1327,9 @@ def cells_plotting( :param color_key: the key to get the data to color the plot, it is ignored when the `color_by` is set to 'total_counts' or 'n_genes_by_counts'. :param bgcolor: set background color. :param palette: color theme, - when `color_by` is 'cluster', it can be a palette name, a list of colors or a dictionary of colors whose keys is the cluster name, - when other `color_by` is set, it only can be a string of palette name. + when `color_by` is 'cluster', it can be a palette name, a list of colors whose length equal to the groups, + or a dict whose keys are the groups and values are colors, + when other `color_by` is set, it only can be a palette name. :param width: the figure width in pixels. :param height: the figure height in pixels. :param fg_alpha: the transparency of foreground image, between 0 and 1, defaults to 0.5 @@ -1342,27 +1352,37 @@ def cells_plotting( :param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates. :param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates. + :return the figure object if `show` is set to False, otherwise, show the figure directly. + .. note:: Exporting ------------------ - This plot can be exported as PNG and SVG, then converted to PDF. + This plot can be exported as PNG, SVG or PDF. You need the following necessary dependencies to support exporting: conda install -c conda-forge selenium firefox geckodriver cairosvg - On Linux systems, you may need to install some additional libraries to support the above dependencies, - for example, on Ubuntu, you need to install the following libraries: + On Linux, you may need to install some additional libraries to support the above dependencies, + for example, on Ubuntu, the following libraries need to be installed: sudo apt-get install libgtk-3-dev libasound2-dev - On other linux systems, you may need to install the corresponding libraries according to the error message. + On others Linux, you may need to install the corresponding libraries according to the error message. - There are two ways to export the plot, one is to manupulate on browser, + There are two ways to export the plot, one is to manupulate on browser when you run it on jupyter notebook, another is to call the method `save_plot `_ of this figure object. + Example code for the second way: + + .. code-block:: python + + fig = data.plt.cells_plotting(show=False) + fig.show() + fig.save_plot('path/to/save/plot.pdf') + """ # noqa from .plot_cells import PlotCells if color_by in ('cluster', 'gene'): diff --git a/stereo/plots/plot_ms_spatial_scatter.py b/stereo/plots/plot_ms_spatial_scatter.py index 6b848a8a..036c0134 100644 --- a/stereo/plots/plot_ms_spatial_scatter.py +++ b/stereo/plots/plot_ms_spatial_scatter.py @@ -34,7 +34,6 @@ def ms_spatial_scatter( height: Optional[int] = None, x_label: Optional[Union[list, str]] = 'spatial1', y_label: Optional[Union[list, str]] = 'spatial2', - title: Optional[str] = None, vmin: float = None, vmax: float = None, marker: str = 'o', @@ -61,7 +60,6 @@ def ms_spatial_scatter( by default, it will be set to 6 times of `nrows`. :param x_label: the label of x-axis, defaults to 'spatial1'. :param y_label: the label of y-axis, defaults to 'spatial2'. - :param title: the title of each slice plot, defaults to None. :param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin. :param vmax: The value representing the higher limit of the color scale. Values greater than vmax are plotted with the same color as vmax. vmin and vmax will be ignored when `color_by` is 'cluster'. @@ -162,13 +160,6 @@ def ms_spatial_scatter( x = data.position[:, 0] y = data.position[:, 1] hue = hue_list[idx] - # if color_by == 'cluster': - # title = color_key - # elif color_by == 'gene': - # title = color_key - # else: - # title = color_by - # ax.set_title(title) ax.set_title(f'sample {self.ms_data.names[idx]}') if len(hue) == 0: continue @@ -186,7 +177,7 @@ def ms_spatial_scatter( return fig - def _get_row_col(self, ncols: int): + def _get_row_col(self, ncols: int = None): if ncols is None: ncols = self.__default_ncols ncols = min(ncols, self.ms_data.num_slice) diff --git a/stereo/plots/scatter.py b/stereo/plots/scatter.py index 9417653e..28c1c6b1 100644 --- a/stereo/plots/scatter.py +++ b/stereo/plots/scatter.py @@ -271,11 +271,18 @@ def base_scatter( from natsort import natsorted import collections g = natsorted(set(hue)) - if hue_order is not None: - g = hue_order - colors = stereo_conf.get_colors(palette, n=len(g)) - color_dict = collections.OrderedDict(dict([(g[i], colors[i]) for i in range(len(g))])) - sns.scatterplot(x=x, y=y, hue=hue, hue_order=g, linewidth=0, marker=marker, + if hue_order is None: + hue_order = g + # if isinstance(palette, (dict, collections.OrderedDict)): + # palette = [palette[i] for i in g if i in palette] + # if len(palette) < len(g): + # colors = stereo_conf.get_colors(palette, n=len(g)) + # else: + # colors = palette + # color_dict = collections.OrderedDict(dict([(g[i], colors[i]) for i in range(len(g))])) + colors = stereo_conf.get_colors(palette, n=len(g), order=hue_order) + color_dict = dict(zip(hue_order, colors)) + sns.scatterplot(x=x, y=y, hue=hue, hue_order=hue_order, linewidth=0, marker=marker, palette=color_dict, size=hue, sizes=(dot_size, dot_size), ax=ax, alpha=foreground_alpha) handles, labels = ax.get_legend_handles_labels() # ax.legend_.remove() diff --git a/stereo/stereo_config.py b/stereo/stereo_config.py index 6e862577..67330d0d 100644 --- a/stereo/stereo_config.py +++ b/stereo/stereo_config.py @@ -10,12 +10,14 @@ from pathlib import Path from typing import Optional from typing import Union +from collections import OrderedDict from copy import deepcopy import matplotlib.colors as mpl_colors from colorcet import palette, aliases, cetnames_flipped from matplotlib import rcParams from matplotlib import rcParamsDefault +import numpy as np class StereoConfig(object): @@ -46,26 +48,40 @@ def __init__( self._log_format = log_format self.out_dir = output self.data_dir = data_dir if data_dir else os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data') + self._palette_custom = None + + + @property + def palette_custom(self): + return self._palette_custom + + @palette_custom.setter + def palette_custom(self, palette_custom): + if not isinstance(palette_custom, list): + raise ValueError('palette_custom should be a list of colors') + self._palette_custom = palette_custom @property def colormaps(self): - # colormaps = deepcopy(palette) - colormaps = {k: v for k, v in palette.items() if 'glasbey' in k and '_bw_' not in k} - colormaps['stereo_30'] = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#A65628", "#FFFF33", - "#F781BF", "#999999", "#E5D8BD", "#B3CDE3", "#CCEBC5", "#FED9A6", "#FBB4AE", - "#8DD3C7", "#BEBADA", "#80B1D3", "#B3DE69", "#FCCDE5", "#BC80BD", "#FFED6F", - "#8DA0CB", "#E78AC3", "#E5C494", "#CCCCCC", "#FB9A99", "#E31A1C", "#CAB2D6", - "#6A3D9A", "#B15928"] + color_keys = sorted([k for k in palette.keys() if 'glasbey' in k and '_bw_' not in k]) + colormaps = OrderedDict([(k, palette[k]) for k in color_keys]) + # colormaps = {k: v for k, v in palette.items() if 'glasbey' in k and '_bw_' not in k} + colormaps['stereo_30'] = [ + "#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#A65628", + "#FFFF33", "#F781BF", "#999999", "#E5D8BD", "#B3CDE3", "#CCEBC5", + "#FED9A6", "#FBB4AE", "#8DD3C7", "#BEBADA", "#80B1D3", "#B3DE69", + "#FCCDE5", "#BC80BD", "#FFED6F", "#8DA0CB", "#E78AC3", "#E5C494", + "#CCCCCC", "#FB9A99", "#E31A1C", "#CAB2D6", "#6A3D9A", "#B15928" + ] + if self.palette_custom is not None: + colormaps['custom'] = self.palette_custom return colormaps @property def linear_colormaps(self): - # colormaps = deepcopy(palette) - colormaps = {} - for k, v in palette.items(): - if 'glasbey' in k or k in aliases or k in cetnames_flipped: - continue - colormaps[k] = v + color_keys = sorted([k for k in palette.keys() if not ('glasbey' in k or k in aliases or k in cetnames_flipped)]) + colormaps = OrderedDict([(k, palette[k]) for k in color_keys]) + stmap_colors = ['#0c3383', '#0a88ba', '#f2d338', '#f28f38', '#d91e1e'] nodes = [0.0, 0.25, 0.50, 0.75, 1.0] mycmap = mpl_colors.LinearSegmentedColormap.from_list("mycmap", list(zip(nodes, stmap_colors))) @@ -75,30 +91,63 @@ def linear_colormaps(self): def linear_colors(self, colors, reverse=False): if isinstance(colors, str): - if colors not in self.linear_colormaps: + linear_colormaps = deepcopy(palette) + linear_colormaps.update(self.linear_colormaps) + if colors not in linear_colormaps: raise ValueError(f'{colors} not in colormaps, color value range in {self.linear_colormaps.keys()}') else: - return self.linear_colormaps[colors][::-1] if reverse else self.linear_colormaps[colors] - elif isinstance(colors, list): - return colors + return linear_colormaps[colors][::-1] if reverse else linear_colormaps[colors] + elif isinstance(colors, (list, tuple, np.ndarray)): + mycmap = mpl_colors.LinearSegmentedColormap.from_list("mycmap", colors) + colors = [mpl_colors.rgb2hex(mycmap(i)) for i in range(mycmap.N)] + return colors[::-1] if reverse else colors else: raise ValueError('colors should be str or list type') - def get_colors(self, colors, n=None): + # def get_colors(self, colors, n=None): + # if isinstance(colors, str): + # colormaps = deepcopy(palette) + # colormaps.update(self.colormaps) + # if colors not in colormaps: + # raise ValueError(f'{colors} not in colormaps, color value range in {self.colormaps.keys()}') + # if n is not None: + # if n > len(colormaps[colors]): + # mycmap = mpl_colors.LinearSegmentedColormap.from_list("mycmap", colormaps[colors], N=n) + # color_list = [mpl_colors.rgb2hex(mycmap(i)) for i in range(n)] + # else: + # color_list = colormaps[colors][0: n] + # return color_list + # else: + # return colormaps[colors] + # else: + # return colors + + def get_colors(self, colors, n=None, order=None): if isinstance(colors, str): - if colors not in self.colormaps: + colormaps = deepcopy(palette) + colormaps.update(self.colormaps) + if colors not in colormaps: raise ValueError(f'{colors} not in colormaps, color value range in {self.colormaps.keys()}') - if n is not None: - if n > len(self.colormaps[colors]): - mycmap = mpl_colors.LinearSegmentedColormap.from_list("mycmap", self.colormaps[colors], N=n) - color_list = [mpl_colors.rgb2hex(mycmap(i)) for i in range(n)] - else: - color_list = self.colormaps[colors][0: n] - return color_list + + colormaps_selected = colormaps[colors] + elif isinstance(colors, (dict, OrderedDict)): + if order is not None: + colormaps_selected = [colors[k] for k in order if k in colors] else: - return self.colormaps[colors] + colormaps_selected = list(colors.values()) + elif isinstance(colors, (list, tuple, np.ndarray)): + colormaps_selected = list(colors) else: - return colors + raise ValueError('colors should be str, dict, list, tuple or np.ndarray type') + + if n is not None: + if n > len(colormaps_selected): + mycmap = mpl_colors.LinearSegmentedColormap.from_list("mycmap", colormaps_selected, N=n) + colormaps_selected = [mpl_colors.rgb2hex(mycmap(i)) for i in range(n)] + else: + colormaps_selected = colormaps_selected[0: n] + + return colormaps_selected @property def log_file(self) -> Union[str, Path, None]: