diff --git a/stereo/plots/plot_cells.py b/stereo/plots/plot_cells.py index 12e74fb6..96035c2f 100644 --- a/stereo/plots/plot_cells.py +++ b/stereo/plots/plot_cells.py @@ -13,23 +13,27 @@ from natsort import natsorted from stereo.stereo_config import stereo_conf +from stereo.core.stereo_exp_data import StereoExpData +from stereo.preprocess.qc import cal_total_counts, cal_n_genes_by_counts, cal_pct_counts_mt class PlotCells: def __init__( self, data, - color_by='total_count', + color_by='total_counts', color_key=None, # cluster_res_key='cluster', bgcolor='#2F2F4F', + palette=None, width=None, height=None, fg_alpha=0.5, base_image=None, - base_im_to_gray=False + base_im_to_gray=False, + use_raw=True ): - self.data = data + self.data: StereoExpData = data # if cluster_res_key in self.data.tl.result: # res = self.data.tl.result[cluster_res_key] # self.cluster_res = np.array(res['group']) @@ -41,29 +45,42 @@ def __init__( # self.cluster_res = None # self.cluster_id = [] if color_by != 'cluster': + self.palette = 'stereo' if palette is None else palette + assert isinstance(self.palette, str), f'The palette must be a name of palette when color_by is {color_by}' self.cluster_res = None self.cluster_id = [] if color_by == 'gene': if not np.any(np.isin(self.data.genes.gene_name, color_key)): raise ValueError(f'The gene {color_key} is not found.') else: + self.palette = 'stereo_30' if palette is None else 'custom' if color_key in self.data.tl.result: res = self.data.tl.result[color_key] self.cluster_res = np.array(res['group']) self.cluster_id = natsorted(np.unique(self.cluster_res).tolist()) n = len(self.cluster_id) - cmap = stereo_conf.get_colors('stereo_30', n) + # if isinstance(palette, str): + # cmap = stereo_conf.get_colors(self.palette, n) + # self.cluster_color_map = OrderedDict({k: v for k, v in zip(self.cluster_id, cmap)}) + if isinstance(palette, (list, np.ndarray)): + stereo_conf.palette_custom = list(palette) + elif isinstance(palette, dict): + stereo_conf.palette_custom = [palette[k] for k in self.cluster_id if k in palette] + cmap = stereo_conf.get_colors(self.palette, n) self.cluster_color_map = OrderedDict({k: v for k, v in zip(self.cluster_id, cmap)}) else: self.cluster_res = None self.cluster_id = [] + self.last_cm_key_continuous = None + self.last_cm_key_discrete = None self.color_by_input = color_by self.color_key = color_key self.bgcolor = bgcolor self.width, self.height = self._set_width_and_height(width, height) self.fg_alpha = fg_alpha self.base_image = base_image + self.base_image_points = None if self.fg_alpha < 0: self.fg_alpha = 0.3 @@ -76,7 +93,15 @@ def __init__( self.hover_fg_alpha = self.fg_alpha / 2 self.figure_polygons = None self.figure_points = None + self.figure_colorbar_legend = None + self.colorbar_or_legend = None self.base_im_to_gray = base_im_to_gray + self.use_raw = use_raw and self.data.raw is not None + self.rangexy_stream = None + self.x_range = None + self.y_range = None + self.firefox_path = None + self.driver_path = None def _set_width_and_height(self, width, height): if width is None or height is None: @@ -115,14 +140,11 @@ def _create_base_image_xarray(self): image_xarray = None with tiff.TiffFile(self.base_image) as tif: - # image_data = tiff.imread(self.base_image) image_data = tif.asarray() if len(image_data.shape) == 3 and self.base_im_to_gray: from cv2 import cvtColor, COLOR_BGR2GRAY image_data = cvtColor(image_data[:, :, [2, 1, 0]], COLOR_BGR2GRAY) if len(image_data.shape) == 3 and image_data.dtype == np.uint16: - # from stereo.image.tissue_cut.tissue_cut_utils.tissue_seg_utils import transfer_16bit_to_8bit - # image_data = transfer_16bit_to_8bit(image_data) from matplotlib.colors import Normalize if tif.shaped_metadata is not None: metadata = tif.shaped_metadata[0] @@ -160,22 +182,41 @@ def _create_polygons(self, color_by): polygons = [] color = [] position = [] + if self.use_raw: + if self.data.shape != self.data.raw.shape: + cells_isin = np.isin(self.data.raw.cell_names, self.data.cell_names) + genes_isin = np.isin(self.data.raw.gene_names, self.data.gene_names) + exp_matrix = self.data.raw.exp_matrix[cells_isin][:, genes_isin] + gene_names = self.data.raw.gene_names[genes_isin] + else: + exp_matrix = self.data.raw.exp_matrix + gene_names = self.data.raw.gene_names + else: + exp_matrix = self.data.exp_matrix + gene_names = self.data.gene_names + total_counts = cal_total_counts(exp_matrix) + n_genes_by_counts = cal_n_genes_by_counts(exp_matrix) + pct_counts_mt = cal_pct_counts_mt(exp_matrix, gene_names) if color_by == 'gene': in_bool = np.isin(self.data.genes.gene_name, self.color_key) for i, cell_border in enumerate(self.data.cells.cell_border): cell_border = cell_border[cell_border[:, 0] < 32767] + self.data.position[i] cell_border = cell_border.reshape((-1,)).tolist() polygons.append([cell_border]) - if color_by == 'total_count': - color.append(self.data.cells.total_counts[i]) + if color_by == 'total_counts': + # color.append(self.data.cells.total_counts[i]) + color.append(total_counts[i]) elif color_by == 'n_genes_by_counts': - color.append(self.data.cells.n_genes_by_counts[i]) + # color.append(self.data.cells.n_genes_by_counts[i]) + color.append(n_genes_by_counts[i]) elif color_by == 'gene': - color.append(self.data.exp_matrix[i, in_bool].sum()) + color.append(exp_matrix[i, in_bool].sum()) elif color_by == 'cluster': - color.append(self.cluster_res[i] if self.cluster_res is not None else self.data.cells.total_counts[i]) + # color.append(self.cluster_res[i] if self.cluster_res is not None else self.data.cells.total_counts[i]) + color.append(self.cluster_res[i] if self.cluster_res is not None else total_counts[i]) else: - color.append(self.data.cells.total_counts[i]) + # color.append(self.data.cells.total_counts[i]) + color.append(total_counts[i]) position.append(str(tuple(self.data.position[i].astype(np.uint32)))) polygons = spd.geometry.PolygonArray(polygons) @@ -183,9 +224,12 @@ def _create_polygons(self, color_by): 'polygons': polygons, 'color': color, 'position': position, - 'total_counts': self.data.cells.total_counts.astype(np.uint32), - 'pct_counts_mt': self.data.cells.pct_counts_mt, - 'n_genes_by_counts': self.data.cells.n_genes_by_counts.astype(np.uint32), + # 'total_counts': self.data.cells.total_counts.astype(np.uint32), + # 'pct_counts_mt': self.data.cells.pct_counts_mt, + # 'n_genes_by_counts': self.data.cells.n_genes_by_counts.astype(np.uint32), + 'total_counts': total_counts, + 'pct_counts_mt': pct_counts_mt, + 'n_genes_by_counts': n_genes_by_counts, 'cluster_id': np.zeros_like(self.data.cell_names) if self.cluster_res is None else self.cluster_res }) @@ -203,11 +247,17 @@ def _create_polygons(self, color_by): return polygons_detail, hover_tool, vdims def _create_widgets(self): + # self.color_map_key_continuous = pn.widgets.Select( + # value='stereo', options=list(stereo_conf.linear_colormaps.keys()), name='color theme', width=200 + # ) + # self.color_map_key_discrete = pn.widgets.Select( + # value='stereo_30', options=list(stereo_conf.colormaps.keys()), name='color theme', width=200 + # ) self.color_map_key_continuous = pn.widgets.Select( - value='stereo', options=list(stereo_conf.linear_colormaps.keys()), name='color theme', width=200 + value=self.palette, options=list(stereo_conf.linear_colormaps.keys()), name='color theme', width=200 ) self.color_map_key_discrete = pn.widgets.Select( - value='stereo_30', options=list(stereo_conf.colormaps.keys()), name='color theme', width=200 + value=self.palette, options=list(stereo_conf.colormaps.keys()), name='color theme', width=200 ) if self.cluster_res is None: self.color_map_key_discrete.visible = False @@ -233,25 +283,328 @@ def _create_widgets(self): self.cluster_colorpicker = pn.widgets.ColorPicker( name='cluster color', width=70, disabled=True, visible=False ) - color_by_key.extend(['total_count', 'n_genes_by_counts', 'gene']) - # self.color_by = pn.widgets.Select(name='color by', options=color_by_key, value=color_by_key[0], width=200) + color_by_key.extend(['total_counts', 'n_genes_by_counts', 'gene']) self.color_by = pn.widgets.Select(name='color by', options=color_by_key, value=self.color_by_input, width=200) - # if self.color_by_input == 'gene': - # gene_names_selector_value = [self.color_key] if isinstance(self.color_key, str) else self.color_key - # else: - # i = np.argmax(self.data.genes.n_counts) - # gene_names_selector_value = [self.data.genes.gene_name[i]] - # if isinstance(gene_names_selector_value, np.ndarray): - # gene_names_selector_value = gene_names_selector_value.tolist() - # self.gene_names = pn.widgets.MultiSelect(name='gene names', value=gene_names_selector_value, options=self.data.genes.gene_name.tolist(), size=10, width=200) if self.color_by_input == 'gene': gene_names_selector_value = self.color_key else: i = np.argmax(self.data.genes.n_counts) gene_names_selector_value = self.data.genes.gene_name[i] - self.gene_names = pn.widgets.Select(name='gene names', value=gene_names_selector_value, - options=self.data.genes.gene_name.tolist(), size=10, width=200) + gene_idx_sorted = np.argsort(self.data.genes.n_counts * -1) + gene_names_sorted = self.data.genes.gene_name[gene_idx_sorted].tolist() + self.gene_names = pn.widgets.Select(name='genes', value=gene_names_selector_value, + options=gene_names_sorted, size=10, width=200) + + self.save_title = pn.widgets.StaticText(name='', value='Save Plot', width=200) + self.save_file_name = pn.widgets.TextInput(name='file name(.png, .svg or .pdf)', width=200) + self.save_file_width = pn.widgets.IntInput(name='width', value=self.width, width=95) + self.save_file_hight = pn.widgets.IntInput(name='height', value=self.height, width=95) + self.save_button = pn.widgets.Button(name='save', button_type="primary", width=100) + self.save_only_in_view = pn.widgets.Checkbox(name='only in view', value=False, width=100) + self.with_base_image = pn.widgets.Checkbox(name='with base image', value=False, width=100) + self.save_button.on_click(self._save_button_callback) + self.save_message = pn.widgets.StaticText(name='', value='', width=400) + + def _set_firefox_and_driver_path(self): + from sys import executable + from os import environ + import platform + + os_type = platform.system().lower() + executable_dir = os.path.dirname(executable) + environ_path = environ['PATH'].split(os.pathsep) + if executable_dir not in environ_path: + environ_path = [executable_dir] + environ_path + if os_type == 'windows': + bin_paths = [ + executable_dir, + os.path.join(executable_dir, 'Scripts'), + os.path.join(executable_dir, 'Library', 'bin'), + os.path.join(executable_dir, 'Library', 'mingw-w64', 'bin'), + os.path.join(executable_dir, 'Library', 'usr', 'bin'), + os.path.join(executable_dir, 'bin') + ] + for bin_path in bin_paths: + if bin_path not in environ_path: + environ_path = [bin_path] + environ_path + a_path = os.path.join(bin_path, 'firefox.exe') + if os.path.exists(a_path) and self.firefox_path is None: + self.firefox_path = a_path + a_path = os.path.join(bin_path, 'geckodriver.exe') + if os.path.exists(a_path) and self.driver_path is None: + self.driver_path = a_path + elif os_type == 'linux': + if self.firefox_path is None: + self.firefox_path = os.path.join(executable_dir, 'firefox') + if self.driver_path is None: + self.driver_path = os.path.join(executable_dir, 'geckodriver') + else: + raise ValueError(f'The operating system {os_type} is not supported.') + environ['PATH'] = os.pathsep.join(environ_path) + + def save_plot( + self, + save_file_name: str, + save_width: int = None, + save_height: int = None, + save_only_in_view: bool = False, + with_base_image: bool = False + ): + """ + Save the plot to a PNG, SVG or PDF file depending on the extension of the file name. + + :param save_file_name: the name of the file to save the plot. + :param save_width: the width of the saved plot, defaults to be the same as the plot. + :param save_height: the height of the saved plot, defaults to be the same as the plot. + :param save_only_in_view: only save the plot in the view, defaults to False. + :param with_base_image: whether to save the plot with the base image, defaults to False. + Currently, the resolution of the saved base image may not be high. + + """ + self._set_firefox_and_driver_path() + + from selenium import webdriver + from selenium.webdriver import FirefoxOptions, FirefoxService + from bokeh.io import export_png, export_svg + from bokeh.io.export import get_svg + from bokeh.layouts import row, Row + from cairosvg import svg2pdf + + if save_file_name == '': + raise ValueError('Please input the file name.') + + if not save_file_name.lower().endswith('.png') and \ + not save_file_name.lower().endswith('.svg') and \ + not save_file_name.lower().endswith('.pdf'): + raise ValueError('Only PNG, SVG and PDF files are supported.') + + save_width = self.width if save_width is None else save_width + save_height = self.height if save_height is None else save_height + if with_base_image: + figure_points = self._create_base_image_figure() + else: + figure_points = None + with_base_image = with_base_image and figure_points is not None + + file_name_prefix, file_name_suffix = os.path.splitext(save_file_name) + save_file_name = f"{file_name_prefix}_cells_plotting{file_name_suffix}" + + current_x_range = self.x_range + current_y_range = self.y_range + try: + if save_only_in_view: + if self.x_range is not None and self.y_range is not None: + figure_polygons = self.figure_polygons[self.x_range[0]:self.x_range[1], self.y_range[0]:self.y_range[1]] + if with_base_image: + figure_points = figure_points.select(x=self.x_range, y=self.y_range) + else: + figure_polygons = self.figure_polygons + else: + figure_polygons = self.figure_polygons + + if not with_base_image: + output_render = hv.render(figure_polygons, backend='bokeh') + for renderer in output_render.renderers: + renderer.glyph.fill_alpha = 1 + else: + output_render = hv.render(figure_points, backend='bokeh') + polygons_render = hv.render(figure_polygons, backend='bokeh') + output_render.renderers.extend(polygons_render.renderers) + output_render.output_backend = 'svg' + output_render.toolbar_location = None + output_render.border_fill_color = None + output_render.outline_line_color = 'gray' + output_render.xaxis.visible = False + output_render.yaxis.visible = False + output_render.width = save_width + output_render.height = save_height + + if save_height == self.height: + figure_colorbar_legend = self.figure_colorbar_legend + else: + figure_colorbar_legend = self._create_colorbar_or_legend( + self.colorbar_or_legend, + self.figure_polygons.opts['cmap'], + self.figure_polygons.data, + save_height + ) + if isinstance(figure_colorbar_legend, Row): + for f in figure_colorbar_legend.children: + f.output_backend = 'svg' + f.height = save_height if save_height > 500 else 500 + else: + figure_colorbar_legend.output_backend = 'svg' + figure_colorbar_legend.height = save_height if save_height > 500 else 500 + + to_save_instance = row(output_render, figure_colorbar_legend) + + opts = FirefoxOptions() + opts.add_argument("--headless") + opts.binary_location = self.firefox_path + service = FirefoxService(executable_path=self.driver_path) + with webdriver.Firefox(options=opts, service=service) as driver: + if save_file_name.lower().endswith('png'): + export_png(to_save_instance, filename=save_file_name, webdriver=driver, timeout=86400) + elif save_file_name.lower().endswith('svg'): + export_svg(to_save_instance, filename=save_file_name, webdriver=driver, timeout=86400) + elif save_file_name.lower().endswith('pdf'): + svg = get_svg(to_save_instance, driver=driver, timeout=86400)[0] + svg2pdf(bytestring=svg, write_to=save_file_name) + except Exception as e: + raise e + finally: + self.rangexy_stream.event(x_range=current_x_range, y_range=current_y_range) + if figure_colorbar_legend is self.figure_colorbar_legend: + if isinstance(figure_colorbar_legend, Row): + for f in figure_colorbar_legend.children: + f.height = self.height if self.height > 500 else 500 + else: + figure_colorbar_legend.height = self.height if self.height > 500 else 500 + return save_file_name + + + def _save_button_callback(self, _): + """ + apt-get install libgtk-3-dev libasound2-dev + conda install -c conda-forge selenium firefox geckodriver cairosvg + """ + self.save_button.loading = True + self.save_message.value = '' + try: + save_file_name = self.save_plot( + self.save_file_name.value, + self.save_file_width.value, + self.save_file_hight.value, + save_only_in_view=self.save_only_in_view.value, + with_base_image=self.with_base_image.value, + ) + self.save_message.value = f'The plot has been saved to {save_file_name}.' + except ValueError as e: + self.save_message.value = f'{str(e)}' + except Exception as e: + raise e + finally: + self.save_button.loading = False + + + def _create_colorbar_or_legend(self, type, cmap, plot_data=None, figure_height=None): + + from bokeh.plotting import figure as bokeh_figure + from bokeh.models import ( + Legend, LegendItem, ColorBar, + EqHistColorMapper, BinnedTicker, + ColumnDataSource + ) + from bokeh.layouts import row + + figure_height = self.height if figure_height is None else figure_height + figure_height = 500 if figure_height < 500 else figure_height + if type == 'colorbar': + figures = 1 + else: + legend_items_in_col = int(figure_height / 25) + legend_counts = len(cmap.keys()) + legend_cols, legend_left = divmod(legend_counts, legend_items_in_col) + if legend_left > 0: + legend_cols += 1 + figures = legend_cols + + figure_list = [] + for i in range(figures): + f = bokeh_figure(width=80, height=figure_height, toolbar_location=None, x_axis_type=None, y_axis_type=None) + f.outline_line_color = None + f.xgrid.grid_line_color = None + f.ygrid.grid_line_color = None + figure_list.append(f) + + if type == 'colorbar': + min_value = min(plot_data['color']) + max_value = max(plot_data['color']) + ticks_num = 100 + ticks_interval = (max_value - min_value) / (ticks_num - 1) + ticks = [min(min_value + i * ticks_interval, max_value) for i in range(ticks_num)] + color_mapper = EqHistColorMapper(palette=cmap, low=min_value, high=max_value) + + data_source = ColumnDataSource(data={ + 'x': [0] * len(ticks), + 'y': [0] * len(ticks), + 'color': ticks + }) + figure_list[0].circle(x='x', y='y', color={'field': 'color', 'transform': color_mapper}, size=0, source=data_source) + ticker = BinnedTicker(mapper=color_mapper, num_major_ticks=8) + color_bar = ColorBar(color_mapper=color_mapper, location='center_left' if self.height >= 500 else 'top_left', + orientation='vertical', height=int(figure_height // 1.5), width=int(figure_height / 500 * 20), + major_tick_line_color='black', major_label_text_font_size=f'{int(figure_height / 500 * 11)}px', + major_tick_in=int(figure_height / 500 * 5), major_tick_line_width=int(figure_height / 500 * 1), + ticker=ticker) + figure_list[0].width = int(figure_height / 500 * 150) + figure_list[0].add_layout(color_bar, 'center') + fig = figure_list[0] + elif type == 'legend': + legend_labels = list(cmap.keys()) + labels_len = [len(label) for label in legend_labels] + max_label_len = max(labels_len) + figure_width = 13 * max_label_len + if figure_width < 80: + figure_width = 80 + figure_width = int(figure_height / 500 * figure_width) + for i in range(legend_cols): + # legend = Legend(location='top_left', orientation='vertical', border_line_color=None, + # label_text_font_size=f'{int(figure_height / 500 * 13)}px', + # label_height=int(figure_height / 500 * 20), label_width=int(figure_height / 500 * 20)) + legend = Legend(location='top_left', orientation='vertical', border_line_color=None) + labels = legend_labels[i * legend_items_in_col: (i + 1) * legend_items_in_col] + for label in labels: + color = cmap[label] + legend.items.append(LegendItem(label=label, renderers=[figure_list[i].circle(x=[0], y=[0], color=color, size=0)])) + figure_list[i].add_layout(legend, 'left') + figure_list[i].width = figure_width + if len(figure_list) == 1: + fig = figure_list[0] + else: + fig = row(*figure_list, sizing_mode='fixed') + + self.colorbar_or_legend = type + + return fig + + def _create_base_image_figure(self): + figure = None + if self.base_image is not None: + if self.base_image_points is None: + self.base_image_points = self._create_base_image_xarray() + if len(self.base_image_points.shape) == 2: + figure = self.base_image_points.hvplot( + cmap='gray', cnorm='eq_hist', hover=False, colorbar=False, + datashade=True, aggregator='mean', dynspread=True + # rasterize=True, aggregator='mean', dynspread=True + ).opts( + bgcolor=self.bgcolor, + width=self.width, + height=self.height, + xaxis=None, + yaxis=None, + invert_yaxis=True + ) + else: + figure = self.base_image_points.hvplot.rgb( + x='x', y='y', bands='channel', hover=False, + datashade=True, aggregator='mean', dynspread=True + # rasterize=True, aggregator='mean', dynspread=True + ).opts( + bgcolor=self.bgcolor, + width=self.width, + height=self.height, + xaxis='bare', + yaxis='bare', + invert_yaxis=True + ) + return figure + + def __rangexy_callback(self, x_range, y_range): + self.x_range = x_range + self.y_range = y_range def show(self): assert self.data.cells.cell_border is not None @@ -281,42 +634,59 @@ def _create_figure(cm_key_continuous_value, cm_key_discrete_value, color_by_valu self.cluster_colorpicker.visible = False self.reverse_colormap.disabled = False if self.color_map_key_discrete.visible is True: - cm_key_value = 'stereo' + cm_key_value = 'stereo' if self.last_cm_key_continuous is None else self.last_cm_key_continuous + # cm_key_value = self.palette if self.last_cm_key_continuous is None else self.last_cm_key_continuous self.color_map_key_continuous.visible = True + self.color_map_key_continuous.value = cm_key_value self.color_map_key_discrete.visible = False else: cm_key_value = cm_key_continuous_value cmap = stereo_conf.linear_colors(cm_key_value, reverse=reverse_colormap_value) + self.last_cm_key_continuous = cm_key_value else: self.cluster.visible = True self.cluster_colorpicker.visible = True + self.reverse_colormap.disabled = True if self.color_map_key_continuous.visible is True: + # cm_key_value = 'stereo_30' if self.last_cm_key_discrete is None else self.last_cm_key_discrete + cm_key_value = self.palette if self.last_cm_key_discrete is None else self.last_cm_key_discrete self.color_map_key_continuous.visible = False self.color_map_key_discrete.visible = True - - self.cluster_color_map[self.cluster.value] = cluster_colorpicker_value - cmap = list(self.cluster_color_map.values()) + self.color_map_key_discrete.value = cm_key_value + else: + cm_key_value = cm_key_discrete_value + + if cm_key_value != self.last_cm_key_discrete: + n = len(self.cluster_id) + colors = stereo_conf.get_colors(cm_key_value, n) + self.cluster_color_map = OrderedDict({k: v for k, v in zip(self.cluster_id, colors)}) + default_cluster_id = self.cluster_id[0] + default_cluster_color = self.cluster_color_map[default_cluster_id] + self.cluster.value = default_cluster_id + self.cluster_colorpicker.value = default_cluster_color + else: + self.cluster_color_map[self.cluster.value] = cluster_colorpicker_value + # cmap = list(self.cluster_color_map.values()) + cmap = self.cluster_color_map + self.last_cm_key_discrete = cm_key_value if self.cluster_res is None or color_by_value != 'cluster': - color = 'color' - show_legend = False - colorbar = True + self.figure_colorbar_legend = self._create_colorbar_or_legend('colorbar', cmap, polygons_detail) else: - color = hv.dim('color').categorize(self.cluster_color_map) - show_legend = True - colorbar = False + self.figure_colorbar_legend = self._create_colorbar_or_legend('legend', cmap) + self.figure_polygons = polygons_detail.hvplot.polygons( 'polygons', hover_cols=vdims ).opts( bgcolor=self.bgcolor, - color=color, + color='color', cnorm='eq_hist', cmap=cmap, width=self.width, height=self.height, - xaxis='bare', - yaxis='bare', + xaxis=None, + yaxis=None, invert_yaxis=True, line_width=1, line_alpha=0, @@ -325,41 +695,21 @@ def _create_figure(cm_key_continuous_value, cm_key_discrete_value, color_by_valu hover_fill_alpha=self.hover_fg_alpha, active_tools=['wheel_zoom'], tools=[hover_tool], - show_legend=show_legend, - colorbar=colorbar + show_legend=False, + colorbar=False ) if self.base_image is not None: - base_image_points_detail = self._create_base_image_xarray() - if len(base_image_points_detail.shape) == 2: - self.figure_points = base_image_points_detail.hvplot( - cmap='gray', cnorm='eq_hist', hover=False, colorbar=False, - # datashade=True, dynspread=True - rasterize=True, aggregator='mean', dynspread=True - ).opts( - bgcolor=self.bgcolor, - width=self.width, - height=self.height, - xaxis='bare', - yaxis='bare', - invert_yaxis=True - ) - else: - self.figure_points = base_image_points_detail.hvplot.rgb( - x='x', y='y', bands='channel', hover=False, - rasterize=True, aggregator='mean', dynspread=True - ).opts( - bgcolor=self.bgcolor, - width=self.width, - height=self.height, - xaxis='bare', - yaxis='bare', - invert_yaxis=True - ) + if self.figure_points is None: + self.figure_points = self._create_base_image_figure() figure = self.figure_points * self.figure_polygons else: figure = self.figure_polygons - return figure + + self.rangexy_stream = hv.streams.RangeXY(source=self.figure_polygons) + self.rangexy_stream.add_subscriber(self.__rangexy_callback) + + return pn.Row(figure, self.figure_colorbar_legend) return pn.Row( pn.Column(_create_figure), @@ -369,6 +719,14 @@ def _create_figure(cm_key_continuous_value, cm_key_discrete_value, color_by_valu self.color_by, self.reverse_colormap, pn.Row(self.cluster, self.cluster_colorpicker), - self.gene_names + self.gene_names, + '', + self.save_title, + self.save_file_name, + pn.Row(self.save_file_width, self.save_file_hight), + self.save_only_in_view, + # self.with_base_image, + self.save_button, + self.save_message, ) ) diff --git a/stereo/plots/plot_collection.py b/stereo/plots/plot_collection.py index c99d1de3..794ded9c 100644 --- a/stereo/plots/plot_collection.py +++ b/stereo/plots/plot_collection.py @@ -206,7 +206,7 @@ def marker_genes_volcano( group_name: str, res_key: Optional[str] = 'marker_genes', hue_order: Optional[set] = ('down', 'normal', 'up'), - colors: Optional[str] = ("#377EB8", "grey", "#E41A1C"), + palette: Optional[str] = ("#377EB8", "grey", "#E41A1C"), alpha: Optional[int] = 1, dot_size: Optional[int] = 15, text_genes: Optional[list] = None, @@ -225,7 +225,7 @@ def marker_genes_volcano( :param group_name: the group name. :param res_key: the result key of marker gene. :param hue_order: the classification method. - :param colors: the color set. + :param palette: the color theme. :param alpha: the opacity. :param dot_size: the dot size. :param text_genes: show gene names. @@ -248,7 +248,7 @@ def marker_genes_volcano( cut_off_pvalue=cut_off_pvalue, cut_off_logFC=cut_off_logFC, hue_order=hue_order, - palette=colors, + palette=palette, alpha=alpha, s=dot_size, x_label=x_label, y_label=y_label, vlines=vlines, @@ -284,7 +284,7 @@ def genes_count( :param out_dpi: the dpi when the figure is saved. """ # noqa import math - import matplotlib.pyplot as plt + # import matplotlib.pyplot as plt from matplotlib import gridspec set_xy_empty = False if x_label == y_label == '' or x_label == y_label == []: @@ -587,9 +587,9 @@ def violin( :param rotation_angle: rotation of xtick labels. :param group_by: the key of the observation grouping to consider. :param multi_panel: Display keys in multiple panels also when groupby is not None. - :param scale: The method used to scale the width of each violin. If ‘width’ (the default), each violin will - have the same width. If ‘area’, each violin will have the same area. - If ‘count’, a violin’s width corresponds to the number of observations. + :param scale: The method used to scale the width of each violin. If 'width' (the default), each violin will + have the same width. If 'area', each violin will have the same area. + If 'count', a violin's width corresponds to the number of observations. :param ax: a matplotlib axes object. only works if plotting a single component. :param order: Order in which to show the categories. :param use_raw: Whether to use raw attribute of data. Defaults to True if .raw is present. @@ -668,7 +668,7 @@ def batches_umap( y_label: Optional[str] = 'umap2', bfig_title: Optional[str] = 'all batches', dot_size: Optional[int] = 1, - colors: Optional[Union[str, list]] = 'stereo_30', + palette: Optional[Union[str, list]] = 'stereo_30', width: Optional[int] = None, height: Optional[int] = None ): @@ -681,7 +681,7 @@ def batches_umap( :param y_label: the y label. :param bfig_title: the big figure title. :param dot_size: the dot size. - :param colors: the color list. + :param palette: the color list. :param width: the figure width in pixels. :param height: the figure height in pixels. @@ -707,7 +707,7 @@ def batches_umap( umap_res['batch'] = self.data.cells.batch.astype(np.uint16) batch_number_unique = np.unique(umap_res['batch']) batch_count = len(batch_number_unique) - cmap = stereo_conf.get_colors(colors, batch_count) + cmap = stereo_conf.get_colors(palette, batch_count) fig_all = umap_res.hvplot.scatter( x='x', y='y', c='batch', cmap=cmap, cnorm='eq_hist', ).opts( @@ -769,7 +769,6 @@ def umap( x_label: Optional[Union[str, list]] = 'umap1', y_label: Optional[Union[str, list]] = 'umap2', dot_size: Optional[int] = None, - colors: Optional[Union[str, list]] = 'stereo', width: Optional[int] = None, height: Optional[int] = None, palette: Optional[int] = None, @@ -787,7 +786,6 @@ def umap( :param x_label: the x label. :param y_label: the y label. :param dot_size: the dot size. - :param colors: the color list. :param width: the figure width in pixels. :param height: the figure height in pixels. :param palette: color theme. @@ -798,13 +796,15 @@ def umap( """ # noqa res = self.check_res_key(res_key) + if palette is None: + palette = 'stereo_30' if cluster_key else 'stereo' if cluster_key: cluster_res = self.check_res_key(cluster_key) n = len(set(cluster_res['group'])) if title is None: title = cluster_key - if not palette: - palette = stereo_conf.get_colors('stereo_30' if colors == 'stereo' else colors, n) + # if not palette: + # palette = stereo_conf.get_colors('stereo_30' if colors == 'stereo' else colors, n) return base_scatter( res.values[:, 0], res.values[:, 1], @@ -828,7 +828,7 @@ def umap( res.values[:, 0], res.values[:, 1], hue=self.data.sub_exp_matrix_by_name(gene_name=gene_names).T, - palette=colors, + palette=palette, title=gene_names if title is None else title, x_label=[x_label for i in range(len(gene_names))], y_label=[y_label for i in range(len(gene_names))], @@ -1298,28 +1298,38 @@ def scenic_clustermap( @reorganize_coordinate def cells_plotting( self, - color_by: Literal['total_count', 'n_genes_by_counts', 'gene', 'cluster'] = 'total_count', + color_by: Literal['total_counts', 'n_genes_by_counts', 'gene', 'cluster'] = 'total_counts', color_key: Optional[str] = None, bgcolor: Optional[str] = '#2F2F4F', + palette: Optional[Union[str, list, dict]] = None, width: Optional[int] = None, height: Optional[int] = None, fg_alpha: Optional[float] = 0.5, base_image: Optional[str] = None, - base_im_to_gray: bool = False + base_im_to_gray: bool = False, + use_raw: bool = True, + show: bool = True ): """Plot the cells. - :param color_by: spcify the way of coloring, default to 'total_count'. + :param color_by: spcify the way of coloring, default to 'total_counts'. if set to 'gene', you need to specify a gene name by `color_key`. if set to 'cluster', you need to specify the key to get cluster result by `color_key`. - :param color_key: the key to get the data to color the plot, it is ignored when the `color_by` is set to 'total_count' or 'n_genes_by_counts'. + :param color_key: the key to get the data to color the plot, it is ignored when the `color_by` is set to 'total_counts' or 'n_genes_by_counts'. :param bgcolor: set background color. + :param palette: color theme, + when `color_by` is 'cluster', it can be a palette name, a list of colors or a dictionary of colors whose keys is the cluster name, + when other `color_by` is set, it only can be a string of palette name. :param width: the figure width in pixels. :param height: the figure height in pixels. - :param fg_alpha: the alpha of foreground image, between 0 and 1, defaults to 0.5 + :param fg_alpha: the transparency of foreground image, between 0 and 1, defaults to 0.5 this is the colored image of the cells. :param base_image: the path of the ssdna image after calibration, defaults to None it will be located behide the image of the cells. + :param base_im_to_gray: whether to convert the base image to gray scale if base image is RGB/RGBA image. + :param use_raw: whether to use raw data, defaults to True if .raw is present. + :param show: show the figure directly or get the figure object, defaults to True. + If set to False, you need to call the `show` method of the figure object to show the figure. :param reorganize_coordinate: if the data is merged from several slices, whether to reorganize the coordinates of the obs(cells), if set it to a number, like 2, the coordinates will be reorganized to 2 columns on coordinate system as below: --------------- @@ -1331,7 +1341,28 @@ def cells_plotting( if set it to `False`, the coordinates will not be changed. :param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates. :param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates. - :return: Cells distribution figure. + + .. note:: + + Exporting + ------------------ + + This plot can be exported as PNG and SVG, then converted to PDF. + + You need the following necessary dependencies to support exporting: + + conda install -c conda-forge selenium firefox geckodriver cairosvg + + On Linux systems, you may need to install some additional libraries to support the above dependencies, + for example, on Ubuntu, you need to install the following libraries: + + sudo apt-get install libgtk-3-dev libasound2-dev + + On other linux systems, you may need to install the corresponding libraries according to the error message. + + There are two ways to export the plot, one is to manupulate on browser, + another is to call the method `save_plot `_ of this figure object. + """ # noqa from .plot_cells import PlotCells if color_by in ('cluster', 'gene'): @@ -1343,13 +1374,17 @@ def cells_plotting( color_key=color_key, # cluster_res_key=cluster_res_key, bgcolor=bgcolor, + palette=palette, width=width, height=height, fg_alpha=fg_alpha, base_image=base_image, - base_im_to_gray=base_im_to_gray + base_im_to_gray=base_im_to_gray, + use_raw=use_raw ) - return pc.show() + if show: + return pc.show() + return pc @download def correlation_heatmap( diff --git a/stereo/preprocess/qc.py b/stereo/preprocess/qc.py index d61a2ae2..1b3ff3d0 100644 --- a/stereo/preprocess/qc.py +++ b/stereo/preprocess/qc.py @@ -27,7 +27,8 @@ def cal_cells_indicators(data): exp_matrix = data.exp_matrix data.cells.total_counts = cal_total_counts(exp_matrix) data.cells.n_genes_by_counts = cal_n_genes_by_counts(exp_matrix) - data.cells.pct_counts_mt = cal_pct_counts_mt(data) + # data.cells.pct_counts_mt = cal_pct_counts_mt(data) + data.cells.pct_counts_mt = cal_pct_counts_mt(exp_matrix, data.gene_names) return data @@ -35,7 +36,8 @@ def cal_genes_indicators(data): exp_matrix = data.exp_matrix data.genes.n_cells = cal_n_cells(exp_matrix) data.genes.n_counts = cal_per_gene_counts(exp_matrix) - data.genes.mean_umi = cal_gene_mean_umi(data) + # data.genes.mean_umi = cal_gene_mean_umi(data) + data.genes.mean_umi = cal_gene_mean_umi(exp_matrix) return data @@ -79,9 +81,20 @@ def cal_n_cells(exp_matrix): return exp_matrix.getnnz(axis=0) if issparse(exp_matrix) else np.count_nonzero(exp_matrix, axis=0) -def cal_gene_mean_umi(data): +# def cal_gene_mean_umi(data): +# old_settings = np.seterr(divide='ignore', invalid='ignore') +# gene_mean_umi = data.genes.n_counts / data.genes.n_cells +# flag = np.isnan(gene_mean_umi) | np.isinf(gene_mean_umi) +# gene_mean_umi[flag] = 0 +# np.seterr(**old_settings) +# return gene_mean_umi + +def cal_gene_mean_umi(exp_matrix): old_settings = np.seterr(divide='ignore', invalid='ignore') - gene_mean_umi = data.genes.n_counts / data.genes.n_cells + # gene_mean_umi = data.genes.n_counts / data.genes.n_cells + n_counts = cal_per_gene_counts(exp_matrix) + n_cells = cal_n_cells(exp_matrix) + gene_mean_umi = n_counts / n_cells flag = np.isnan(gene_mean_umi) | np.isinf(gene_mean_umi) gene_mean_umi[flag] = 0 np.seterr(**old_settings) @@ -92,13 +105,24 @@ def cal_n_genes_by_counts(exp_matrix): return exp_matrix.getnnz(axis=1) if issparse(exp_matrix) else np.count_nonzero(exp_matrix, axis=1) -def cal_pct_counts_mt(data): +# def cal_pct_counts_mt(data): +# old_settings = np.seterr(divide='ignore', invalid='ignore') +# if data.cells.total_counts is None: +# data.cells.total_counts = cal_total_counts(data.exp_matrix) +# mt_index = np.char.startswith(np.char.lower(data.gene_names), prefix='mt-') +# mt_count = np.array(data.exp_matrix[:, mt_index].sum(1)).reshape(-1) +# pct_counts_mt = mt_count / data.cells.total_counts * 100 +# flag = np.isnan(pct_counts_mt) | np.isinf(pct_counts_mt) +# pct_counts_mt[flag] = 0 +# np.seterr(**old_settings) +# return pct_counts_mt + +def cal_pct_counts_mt(exp_matrix, gene_names): old_settings = np.seterr(divide='ignore', invalid='ignore') - if data.cells.total_counts is None: - data.cells.total_counts = cal_total_counts(data.exp_matrix) - mt_index = np.char.startswith(np.char.lower(data.gene_names), prefix='mt-') - mt_count = np.array(data.exp_matrix[:, mt_index].sum(1)).reshape(-1) - pct_counts_mt = mt_count / data.cells.total_counts * 100 + total_counts = cal_total_counts(exp_matrix) + mt_index = np.char.startswith(np.char.lower(gene_names), prefix='mt-') + mt_count = np.array(exp_matrix[:, mt_index].sum(1)).reshape(-1) + pct_counts_mt = mt_count / total_counts * 100 flag = np.isnan(pct_counts_mt) | np.isinf(pct_counts_mt) pct_counts_mt[flag] = 0 np.seterr(**old_settings)