diff --git a/stereo/plots/plot_cells.py b/stereo/plots/plot_cells.py
index 12e74fb6..96035c2f 100644
--- a/stereo/plots/plot_cells.py
+++ b/stereo/plots/plot_cells.py
@@ -13,23 +13,27 @@
from natsort import natsorted
from stereo.stereo_config import stereo_conf
+from stereo.core.stereo_exp_data import StereoExpData
+from stereo.preprocess.qc import cal_total_counts, cal_n_genes_by_counts, cal_pct_counts_mt
class PlotCells:
def __init__(
self,
data,
- color_by='total_count',
+ color_by='total_counts',
color_key=None,
# cluster_res_key='cluster',
bgcolor='#2F2F4F',
+ palette=None,
width=None,
height=None,
fg_alpha=0.5,
base_image=None,
- base_im_to_gray=False
+ base_im_to_gray=False,
+ use_raw=True
):
- self.data = data
+ self.data: StereoExpData = data
# if cluster_res_key in self.data.tl.result:
# res = self.data.tl.result[cluster_res_key]
# self.cluster_res = np.array(res['group'])
@@ -41,29 +45,42 @@ def __init__(
# self.cluster_res = None
# self.cluster_id = []
if color_by != 'cluster':
+ self.palette = 'stereo' if palette is None else palette
+ assert isinstance(self.palette, str), f'The palette must be a name of palette when color_by is {color_by}'
self.cluster_res = None
self.cluster_id = []
if color_by == 'gene':
if not np.any(np.isin(self.data.genes.gene_name, color_key)):
raise ValueError(f'The gene {color_key} is not found.')
else:
+ self.palette = 'stereo_30' if palette is None else 'custom'
if color_key in self.data.tl.result:
res = self.data.tl.result[color_key]
self.cluster_res = np.array(res['group'])
self.cluster_id = natsorted(np.unique(self.cluster_res).tolist())
n = len(self.cluster_id)
- cmap = stereo_conf.get_colors('stereo_30', n)
+ # if isinstance(palette, str):
+ # cmap = stereo_conf.get_colors(self.palette, n)
+ # self.cluster_color_map = OrderedDict({k: v for k, v in zip(self.cluster_id, cmap)})
+ if isinstance(palette, (list, np.ndarray)):
+ stereo_conf.palette_custom = list(palette)
+ elif isinstance(palette, dict):
+ stereo_conf.palette_custom = [palette[k] for k in self.cluster_id if k in palette]
+ cmap = stereo_conf.get_colors(self.palette, n)
self.cluster_color_map = OrderedDict({k: v for k, v in zip(self.cluster_id, cmap)})
else:
self.cluster_res = None
self.cluster_id = []
+ self.last_cm_key_continuous = None
+ self.last_cm_key_discrete = None
self.color_by_input = color_by
self.color_key = color_key
self.bgcolor = bgcolor
self.width, self.height = self._set_width_and_height(width, height)
self.fg_alpha = fg_alpha
self.base_image = base_image
+ self.base_image_points = None
if self.fg_alpha < 0:
self.fg_alpha = 0.3
@@ -76,7 +93,15 @@ def __init__(
self.hover_fg_alpha = self.fg_alpha / 2
self.figure_polygons = None
self.figure_points = None
+ self.figure_colorbar_legend = None
+ self.colorbar_or_legend = None
self.base_im_to_gray = base_im_to_gray
+ self.use_raw = use_raw and self.data.raw is not None
+ self.rangexy_stream = None
+ self.x_range = None
+ self.y_range = None
+ self.firefox_path = None
+ self.driver_path = None
def _set_width_and_height(self, width, height):
if width is None or height is None:
@@ -115,14 +140,11 @@ def _create_base_image_xarray(self):
image_xarray = None
with tiff.TiffFile(self.base_image) as tif:
- # image_data = tiff.imread(self.base_image)
image_data = tif.asarray()
if len(image_data.shape) == 3 and self.base_im_to_gray:
from cv2 import cvtColor, COLOR_BGR2GRAY
image_data = cvtColor(image_data[:, :, [2, 1, 0]], COLOR_BGR2GRAY)
if len(image_data.shape) == 3 and image_data.dtype == np.uint16:
- # from stereo.image.tissue_cut.tissue_cut_utils.tissue_seg_utils import transfer_16bit_to_8bit
- # image_data = transfer_16bit_to_8bit(image_data)
from matplotlib.colors import Normalize
if tif.shaped_metadata is not None:
metadata = tif.shaped_metadata[0]
@@ -160,22 +182,41 @@ def _create_polygons(self, color_by):
polygons = []
color = []
position = []
+ if self.use_raw:
+ if self.data.shape != self.data.raw.shape:
+ cells_isin = np.isin(self.data.raw.cell_names, self.data.cell_names)
+ genes_isin = np.isin(self.data.raw.gene_names, self.data.gene_names)
+ exp_matrix = self.data.raw.exp_matrix[cells_isin][:, genes_isin]
+ gene_names = self.data.raw.gene_names[genes_isin]
+ else:
+ exp_matrix = self.data.raw.exp_matrix
+ gene_names = self.data.raw.gene_names
+ else:
+ exp_matrix = self.data.exp_matrix
+ gene_names = self.data.gene_names
+ total_counts = cal_total_counts(exp_matrix)
+ n_genes_by_counts = cal_n_genes_by_counts(exp_matrix)
+ pct_counts_mt = cal_pct_counts_mt(exp_matrix, gene_names)
if color_by == 'gene':
in_bool = np.isin(self.data.genes.gene_name, self.color_key)
for i, cell_border in enumerate(self.data.cells.cell_border):
cell_border = cell_border[cell_border[:, 0] < 32767] + self.data.position[i]
cell_border = cell_border.reshape((-1,)).tolist()
polygons.append([cell_border])
- if color_by == 'total_count':
- color.append(self.data.cells.total_counts[i])
+ if color_by == 'total_counts':
+ # color.append(self.data.cells.total_counts[i])
+ color.append(total_counts[i])
elif color_by == 'n_genes_by_counts':
- color.append(self.data.cells.n_genes_by_counts[i])
+ # color.append(self.data.cells.n_genes_by_counts[i])
+ color.append(n_genes_by_counts[i])
elif color_by == 'gene':
- color.append(self.data.exp_matrix[i, in_bool].sum())
+ color.append(exp_matrix[i, in_bool].sum())
elif color_by == 'cluster':
- color.append(self.cluster_res[i] if self.cluster_res is not None else self.data.cells.total_counts[i])
+ # color.append(self.cluster_res[i] if self.cluster_res is not None else self.data.cells.total_counts[i])
+ color.append(self.cluster_res[i] if self.cluster_res is not None else total_counts[i])
else:
- color.append(self.data.cells.total_counts[i])
+ # color.append(self.data.cells.total_counts[i])
+ color.append(total_counts[i])
position.append(str(tuple(self.data.position[i].astype(np.uint32))))
polygons = spd.geometry.PolygonArray(polygons)
@@ -183,9 +224,12 @@ def _create_polygons(self, color_by):
'polygons': polygons,
'color': color,
'position': position,
- 'total_counts': self.data.cells.total_counts.astype(np.uint32),
- 'pct_counts_mt': self.data.cells.pct_counts_mt,
- 'n_genes_by_counts': self.data.cells.n_genes_by_counts.astype(np.uint32),
+ # 'total_counts': self.data.cells.total_counts.astype(np.uint32),
+ # 'pct_counts_mt': self.data.cells.pct_counts_mt,
+ # 'n_genes_by_counts': self.data.cells.n_genes_by_counts.astype(np.uint32),
+ 'total_counts': total_counts,
+ 'pct_counts_mt': pct_counts_mt,
+ 'n_genes_by_counts': n_genes_by_counts,
'cluster_id': np.zeros_like(self.data.cell_names) if self.cluster_res is None else self.cluster_res
})
@@ -203,11 +247,17 @@ def _create_polygons(self, color_by):
return polygons_detail, hover_tool, vdims
def _create_widgets(self):
+ # self.color_map_key_continuous = pn.widgets.Select(
+ # value='stereo', options=list(stereo_conf.linear_colormaps.keys()), name='color theme', width=200
+ # )
+ # self.color_map_key_discrete = pn.widgets.Select(
+ # value='stereo_30', options=list(stereo_conf.colormaps.keys()), name='color theme', width=200
+ # )
self.color_map_key_continuous = pn.widgets.Select(
- value='stereo', options=list(stereo_conf.linear_colormaps.keys()), name='color theme', width=200
+ value=self.palette, options=list(stereo_conf.linear_colormaps.keys()), name='color theme', width=200
)
self.color_map_key_discrete = pn.widgets.Select(
- value='stereo_30', options=list(stereo_conf.colormaps.keys()), name='color theme', width=200
+ value=self.palette, options=list(stereo_conf.colormaps.keys()), name='color theme', width=200
)
if self.cluster_res is None:
self.color_map_key_discrete.visible = False
@@ -233,25 +283,328 @@ def _create_widgets(self):
self.cluster_colorpicker = pn.widgets.ColorPicker(
name='cluster color', width=70, disabled=True, visible=False
)
- color_by_key.extend(['total_count', 'n_genes_by_counts', 'gene'])
- # self.color_by = pn.widgets.Select(name='color by', options=color_by_key, value=color_by_key[0], width=200)
+ color_by_key.extend(['total_counts', 'n_genes_by_counts', 'gene'])
self.color_by = pn.widgets.Select(name='color by', options=color_by_key, value=self.color_by_input, width=200)
- # if self.color_by_input == 'gene':
- # gene_names_selector_value = [self.color_key] if isinstance(self.color_key, str) else self.color_key
- # else:
- # i = np.argmax(self.data.genes.n_counts)
- # gene_names_selector_value = [self.data.genes.gene_name[i]]
- # if isinstance(gene_names_selector_value, np.ndarray):
- # gene_names_selector_value = gene_names_selector_value.tolist()
- # self.gene_names = pn.widgets.MultiSelect(name='gene names', value=gene_names_selector_value, options=self.data.genes.gene_name.tolist(), size=10, width=200)
if self.color_by_input == 'gene':
gene_names_selector_value = self.color_key
else:
i = np.argmax(self.data.genes.n_counts)
gene_names_selector_value = self.data.genes.gene_name[i]
- self.gene_names = pn.widgets.Select(name='gene names', value=gene_names_selector_value,
- options=self.data.genes.gene_name.tolist(), size=10, width=200)
+ gene_idx_sorted = np.argsort(self.data.genes.n_counts * -1)
+ gene_names_sorted = self.data.genes.gene_name[gene_idx_sorted].tolist()
+ self.gene_names = pn.widgets.Select(name='genes', value=gene_names_selector_value,
+ options=gene_names_sorted, size=10, width=200)
+
+ self.save_title = pn.widgets.StaticText(name='', value='Save Plot', width=200)
+ self.save_file_name = pn.widgets.TextInput(name='file name(.png, .svg or .pdf)', width=200)
+ self.save_file_width = pn.widgets.IntInput(name='width', value=self.width, width=95)
+ self.save_file_hight = pn.widgets.IntInput(name='height', value=self.height, width=95)
+ self.save_button = pn.widgets.Button(name='save', button_type="primary", width=100)
+ self.save_only_in_view = pn.widgets.Checkbox(name='only in view', value=False, width=100)
+ self.with_base_image = pn.widgets.Checkbox(name='with base image', value=False, width=100)
+ self.save_button.on_click(self._save_button_callback)
+ self.save_message = pn.widgets.StaticText(name='', value='', width=400)
+
+ def _set_firefox_and_driver_path(self):
+ from sys import executable
+ from os import environ
+ import platform
+
+ os_type = platform.system().lower()
+ executable_dir = os.path.dirname(executable)
+ environ_path = environ['PATH'].split(os.pathsep)
+ if executable_dir not in environ_path:
+ environ_path = [executable_dir] + environ_path
+ if os_type == 'windows':
+ bin_paths = [
+ executable_dir,
+ os.path.join(executable_dir, 'Scripts'),
+ os.path.join(executable_dir, 'Library', 'bin'),
+ os.path.join(executable_dir, 'Library', 'mingw-w64', 'bin'),
+ os.path.join(executable_dir, 'Library', 'usr', 'bin'),
+ os.path.join(executable_dir, 'bin')
+ ]
+ for bin_path in bin_paths:
+ if bin_path not in environ_path:
+ environ_path = [bin_path] + environ_path
+ a_path = os.path.join(bin_path, 'firefox.exe')
+ if os.path.exists(a_path) and self.firefox_path is None:
+ self.firefox_path = a_path
+ a_path = os.path.join(bin_path, 'geckodriver.exe')
+ if os.path.exists(a_path) and self.driver_path is None:
+ self.driver_path = a_path
+ elif os_type == 'linux':
+ if self.firefox_path is None:
+ self.firefox_path = os.path.join(executable_dir, 'firefox')
+ if self.driver_path is None:
+ self.driver_path = os.path.join(executable_dir, 'geckodriver')
+ else:
+ raise ValueError(f'The operating system {os_type} is not supported.')
+ environ['PATH'] = os.pathsep.join(environ_path)
+
+ def save_plot(
+ self,
+ save_file_name: str,
+ save_width: int = None,
+ save_height: int = None,
+ save_only_in_view: bool = False,
+ with_base_image: bool = False
+ ):
+ """
+ Save the plot to a PNG, SVG or PDF file depending on the extension of the file name.
+
+ :param save_file_name: the name of the file to save the plot.
+ :param save_width: the width of the saved plot, defaults to be the same as the plot.
+ :param save_height: the height of the saved plot, defaults to be the same as the plot.
+ :param save_only_in_view: only save the plot in the view, defaults to False.
+ :param with_base_image: whether to save the plot with the base image, defaults to False.
+ Currently, the resolution of the saved base image may not be high.
+
+ """
+ self._set_firefox_and_driver_path()
+
+ from selenium import webdriver
+ from selenium.webdriver import FirefoxOptions, FirefoxService
+ from bokeh.io import export_png, export_svg
+ from bokeh.io.export import get_svg
+ from bokeh.layouts import row, Row
+ from cairosvg import svg2pdf
+
+ if save_file_name == '':
+ raise ValueError('Please input the file name.')
+
+ if not save_file_name.lower().endswith('.png') and \
+ not save_file_name.lower().endswith('.svg') and \
+ not save_file_name.lower().endswith('.pdf'):
+ raise ValueError('Only PNG, SVG and PDF files are supported.')
+
+ save_width = self.width if save_width is None else save_width
+ save_height = self.height if save_height is None else save_height
+ if with_base_image:
+ figure_points = self._create_base_image_figure()
+ else:
+ figure_points = None
+ with_base_image = with_base_image and figure_points is not None
+
+ file_name_prefix, file_name_suffix = os.path.splitext(save_file_name)
+ save_file_name = f"{file_name_prefix}_cells_plotting{file_name_suffix}"
+
+ current_x_range = self.x_range
+ current_y_range = self.y_range
+ try:
+ if save_only_in_view:
+ if self.x_range is not None and self.y_range is not None:
+ figure_polygons = self.figure_polygons[self.x_range[0]:self.x_range[1], self.y_range[0]:self.y_range[1]]
+ if with_base_image:
+ figure_points = figure_points.select(x=self.x_range, y=self.y_range)
+ else:
+ figure_polygons = self.figure_polygons
+ else:
+ figure_polygons = self.figure_polygons
+
+ if not with_base_image:
+ output_render = hv.render(figure_polygons, backend='bokeh')
+ for renderer in output_render.renderers:
+ renderer.glyph.fill_alpha = 1
+ else:
+ output_render = hv.render(figure_points, backend='bokeh')
+ polygons_render = hv.render(figure_polygons, backend='bokeh')
+ output_render.renderers.extend(polygons_render.renderers)
+ output_render.output_backend = 'svg'
+ output_render.toolbar_location = None
+ output_render.border_fill_color = None
+ output_render.outline_line_color = 'gray'
+ output_render.xaxis.visible = False
+ output_render.yaxis.visible = False
+ output_render.width = save_width
+ output_render.height = save_height
+
+ if save_height == self.height:
+ figure_colorbar_legend = self.figure_colorbar_legend
+ else:
+ figure_colorbar_legend = self._create_colorbar_or_legend(
+ self.colorbar_or_legend,
+ self.figure_polygons.opts['cmap'],
+ self.figure_polygons.data,
+ save_height
+ )
+ if isinstance(figure_colorbar_legend, Row):
+ for f in figure_colorbar_legend.children:
+ f.output_backend = 'svg'
+ f.height = save_height if save_height > 500 else 500
+ else:
+ figure_colorbar_legend.output_backend = 'svg'
+ figure_colorbar_legend.height = save_height if save_height > 500 else 500
+
+ to_save_instance = row(output_render, figure_colorbar_legend)
+
+ opts = FirefoxOptions()
+ opts.add_argument("--headless")
+ opts.binary_location = self.firefox_path
+ service = FirefoxService(executable_path=self.driver_path)
+ with webdriver.Firefox(options=opts, service=service) as driver:
+ if save_file_name.lower().endswith('png'):
+ export_png(to_save_instance, filename=save_file_name, webdriver=driver, timeout=86400)
+ elif save_file_name.lower().endswith('svg'):
+ export_svg(to_save_instance, filename=save_file_name, webdriver=driver, timeout=86400)
+ elif save_file_name.lower().endswith('pdf'):
+ svg = get_svg(to_save_instance, driver=driver, timeout=86400)[0]
+ svg2pdf(bytestring=svg, write_to=save_file_name)
+ except Exception as e:
+ raise e
+ finally:
+ self.rangexy_stream.event(x_range=current_x_range, y_range=current_y_range)
+ if figure_colorbar_legend is self.figure_colorbar_legend:
+ if isinstance(figure_colorbar_legend, Row):
+ for f in figure_colorbar_legend.children:
+ f.height = self.height if self.height > 500 else 500
+ else:
+ figure_colorbar_legend.height = self.height if self.height > 500 else 500
+ return save_file_name
+
+
+ def _save_button_callback(self, _):
+ """
+ apt-get install libgtk-3-dev libasound2-dev
+ conda install -c conda-forge selenium firefox geckodriver cairosvg
+ """
+ self.save_button.loading = True
+ self.save_message.value = ''
+ try:
+ save_file_name = self.save_plot(
+ self.save_file_name.value,
+ self.save_file_width.value,
+ self.save_file_hight.value,
+ save_only_in_view=self.save_only_in_view.value,
+ with_base_image=self.with_base_image.value,
+ )
+ self.save_message.value = f'The plot has been saved to {save_file_name}.'
+ except ValueError as e:
+ self.save_message.value = f'{str(e)}'
+ except Exception as e:
+ raise e
+ finally:
+ self.save_button.loading = False
+
+
+ def _create_colorbar_or_legend(self, type, cmap, plot_data=None, figure_height=None):
+
+ from bokeh.plotting import figure as bokeh_figure
+ from bokeh.models import (
+ Legend, LegendItem, ColorBar,
+ EqHistColorMapper, BinnedTicker,
+ ColumnDataSource
+ )
+ from bokeh.layouts import row
+
+ figure_height = self.height if figure_height is None else figure_height
+ figure_height = 500 if figure_height < 500 else figure_height
+ if type == 'colorbar':
+ figures = 1
+ else:
+ legend_items_in_col = int(figure_height / 25)
+ legend_counts = len(cmap.keys())
+ legend_cols, legend_left = divmod(legend_counts, legend_items_in_col)
+ if legend_left > 0:
+ legend_cols += 1
+ figures = legend_cols
+
+ figure_list = []
+ for i in range(figures):
+ f = bokeh_figure(width=80, height=figure_height, toolbar_location=None, x_axis_type=None, y_axis_type=None)
+ f.outline_line_color = None
+ f.xgrid.grid_line_color = None
+ f.ygrid.grid_line_color = None
+ figure_list.append(f)
+
+ if type == 'colorbar':
+ min_value = min(plot_data['color'])
+ max_value = max(plot_data['color'])
+ ticks_num = 100
+ ticks_interval = (max_value - min_value) / (ticks_num - 1)
+ ticks = [min(min_value + i * ticks_interval, max_value) for i in range(ticks_num)]
+ color_mapper = EqHistColorMapper(palette=cmap, low=min_value, high=max_value)
+
+ data_source = ColumnDataSource(data={
+ 'x': [0] * len(ticks),
+ 'y': [0] * len(ticks),
+ 'color': ticks
+ })
+ figure_list[0].circle(x='x', y='y', color={'field': 'color', 'transform': color_mapper}, size=0, source=data_source)
+ ticker = BinnedTicker(mapper=color_mapper, num_major_ticks=8)
+ color_bar = ColorBar(color_mapper=color_mapper, location='center_left' if self.height >= 500 else 'top_left',
+ orientation='vertical', height=int(figure_height // 1.5), width=int(figure_height / 500 * 20),
+ major_tick_line_color='black', major_label_text_font_size=f'{int(figure_height / 500 * 11)}px',
+ major_tick_in=int(figure_height / 500 * 5), major_tick_line_width=int(figure_height / 500 * 1),
+ ticker=ticker)
+ figure_list[0].width = int(figure_height / 500 * 150)
+ figure_list[0].add_layout(color_bar, 'center')
+ fig = figure_list[0]
+ elif type == 'legend':
+ legend_labels = list(cmap.keys())
+ labels_len = [len(label) for label in legend_labels]
+ max_label_len = max(labels_len)
+ figure_width = 13 * max_label_len
+ if figure_width < 80:
+ figure_width = 80
+ figure_width = int(figure_height / 500 * figure_width)
+ for i in range(legend_cols):
+ # legend = Legend(location='top_left', orientation='vertical', border_line_color=None,
+ # label_text_font_size=f'{int(figure_height / 500 * 13)}px',
+ # label_height=int(figure_height / 500 * 20), label_width=int(figure_height / 500 * 20))
+ legend = Legend(location='top_left', orientation='vertical', border_line_color=None)
+ labels = legend_labels[i * legend_items_in_col: (i + 1) * legend_items_in_col]
+ for label in labels:
+ color = cmap[label]
+ legend.items.append(LegendItem(label=label, renderers=[figure_list[i].circle(x=[0], y=[0], color=color, size=0)]))
+ figure_list[i].add_layout(legend, 'left')
+ figure_list[i].width = figure_width
+ if len(figure_list) == 1:
+ fig = figure_list[0]
+ else:
+ fig = row(*figure_list, sizing_mode='fixed')
+
+ self.colorbar_or_legend = type
+
+ return fig
+
+ def _create_base_image_figure(self):
+ figure = None
+ if self.base_image is not None:
+ if self.base_image_points is None:
+ self.base_image_points = self._create_base_image_xarray()
+ if len(self.base_image_points.shape) == 2:
+ figure = self.base_image_points.hvplot(
+ cmap='gray', cnorm='eq_hist', hover=False, colorbar=False,
+ datashade=True, aggregator='mean', dynspread=True
+ # rasterize=True, aggregator='mean', dynspread=True
+ ).opts(
+ bgcolor=self.bgcolor,
+ width=self.width,
+ height=self.height,
+ xaxis=None,
+ yaxis=None,
+ invert_yaxis=True
+ )
+ else:
+ figure = self.base_image_points.hvplot.rgb(
+ x='x', y='y', bands='channel', hover=False,
+ datashade=True, aggregator='mean', dynspread=True
+ # rasterize=True, aggregator='mean', dynspread=True
+ ).opts(
+ bgcolor=self.bgcolor,
+ width=self.width,
+ height=self.height,
+ xaxis='bare',
+ yaxis='bare',
+ invert_yaxis=True
+ )
+ return figure
+
+ def __rangexy_callback(self, x_range, y_range):
+ self.x_range = x_range
+ self.y_range = y_range
def show(self):
assert self.data.cells.cell_border is not None
@@ -281,42 +634,59 @@ def _create_figure(cm_key_continuous_value, cm_key_discrete_value, color_by_valu
self.cluster_colorpicker.visible = False
self.reverse_colormap.disabled = False
if self.color_map_key_discrete.visible is True:
- cm_key_value = 'stereo'
+ cm_key_value = 'stereo' if self.last_cm_key_continuous is None else self.last_cm_key_continuous
+ # cm_key_value = self.palette if self.last_cm_key_continuous is None else self.last_cm_key_continuous
self.color_map_key_continuous.visible = True
+ self.color_map_key_continuous.value = cm_key_value
self.color_map_key_discrete.visible = False
else:
cm_key_value = cm_key_continuous_value
cmap = stereo_conf.linear_colors(cm_key_value, reverse=reverse_colormap_value)
+ self.last_cm_key_continuous = cm_key_value
else:
self.cluster.visible = True
self.cluster_colorpicker.visible = True
+ self.reverse_colormap.disabled = True
if self.color_map_key_continuous.visible is True:
+ # cm_key_value = 'stereo_30' if self.last_cm_key_discrete is None else self.last_cm_key_discrete
+ cm_key_value = self.palette if self.last_cm_key_discrete is None else self.last_cm_key_discrete
self.color_map_key_continuous.visible = False
self.color_map_key_discrete.visible = True
-
- self.cluster_color_map[self.cluster.value] = cluster_colorpicker_value
- cmap = list(self.cluster_color_map.values())
+ self.color_map_key_discrete.value = cm_key_value
+ else:
+ cm_key_value = cm_key_discrete_value
+
+ if cm_key_value != self.last_cm_key_discrete:
+ n = len(self.cluster_id)
+ colors = stereo_conf.get_colors(cm_key_value, n)
+ self.cluster_color_map = OrderedDict({k: v for k, v in zip(self.cluster_id, colors)})
+ default_cluster_id = self.cluster_id[0]
+ default_cluster_color = self.cluster_color_map[default_cluster_id]
+ self.cluster.value = default_cluster_id
+ self.cluster_colorpicker.value = default_cluster_color
+ else:
+ self.cluster_color_map[self.cluster.value] = cluster_colorpicker_value
+ # cmap = list(self.cluster_color_map.values())
+ cmap = self.cluster_color_map
+ self.last_cm_key_discrete = cm_key_value
if self.cluster_res is None or color_by_value != 'cluster':
- color = 'color'
- show_legend = False
- colorbar = True
+ self.figure_colorbar_legend = self._create_colorbar_or_legend('colorbar', cmap, polygons_detail)
else:
- color = hv.dim('color').categorize(self.cluster_color_map)
- show_legend = True
- colorbar = False
+ self.figure_colorbar_legend = self._create_colorbar_or_legend('legend', cmap)
+
self.figure_polygons = polygons_detail.hvplot.polygons(
'polygons', hover_cols=vdims
).opts(
bgcolor=self.bgcolor,
- color=color,
+ color='color',
cnorm='eq_hist',
cmap=cmap,
width=self.width,
height=self.height,
- xaxis='bare',
- yaxis='bare',
+ xaxis=None,
+ yaxis=None,
invert_yaxis=True,
line_width=1,
line_alpha=0,
@@ -325,41 +695,21 @@ def _create_figure(cm_key_continuous_value, cm_key_discrete_value, color_by_valu
hover_fill_alpha=self.hover_fg_alpha,
active_tools=['wheel_zoom'],
tools=[hover_tool],
- show_legend=show_legend,
- colorbar=colorbar
+ show_legend=False,
+ colorbar=False
)
if self.base_image is not None:
- base_image_points_detail = self._create_base_image_xarray()
- if len(base_image_points_detail.shape) == 2:
- self.figure_points = base_image_points_detail.hvplot(
- cmap='gray', cnorm='eq_hist', hover=False, colorbar=False,
- # datashade=True, dynspread=True
- rasterize=True, aggregator='mean', dynspread=True
- ).opts(
- bgcolor=self.bgcolor,
- width=self.width,
- height=self.height,
- xaxis='bare',
- yaxis='bare',
- invert_yaxis=True
- )
- else:
- self.figure_points = base_image_points_detail.hvplot.rgb(
- x='x', y='y', bands='channel', hover=False,
- rasterize=True, aggregator='mean', dynspread=True
- ).opts(
- bgcolor=self.bgcolor,
- width=self.width,
- height=self.height,
- xaxis='bare',
- yaxis='bare',
- invert_yaxis=True
- )
+ if self.figure_points is None:
+ self.figure_points = self._create_base_image_figure()
figure = self.figure_points * self.figure_polygons
else:
figure = self.figure_polygons
- return figure
+
+ self.rangexy_stream = hv.streams.RangeXY(source=self.figure_polygons)
+ self.rangexy_stream.add_subscriber(self.__rangexy_callback)
+
+ return pn.Row(figure, self.figure_colorbar_legend)
return pn.Row(
pn.Column(_create_figure),
@@ -369,6 +719,14 @@ def _create_figure(cm_key_continuous_value, cm_key_discrete_value, color_by_valu
self.color_by,
self.reverse_colormap,
pn.Row(self.cluster, self.cluster_colorpicker),
- self.gene_names
+ self.gene_names,
+ '',
+ self.save_title,
+ self.save_file_name,
+ pn.Row(self.save_file_width, self.save_file_hight),
+ self.save_only_in_view,
+ # self.with_base_image,
+ self.save_button,
+ self.save_message,
)
)
diff --git a/stereo/plots/plot_collection.py b/stereo/plots/plot_collection.py
index c99d1de3..794ded9c 100644
--- a/stereo/plots/plot_collection.py
+++ b/stereo/plots/plot_collection.py
@@ -206,7 +206,7 @@ def marker_genes_volcano(
group_name: str,
res_key: Optional[str] = 'marker_genes',
hue_order: Optional[set] = ('down', 'normal', 'up'),
- colors: Optional[str] = ("#377EB8", "grey", "#E41A1C"),
+ palette: Optional[str] = ("#377EB8", "grey", "#E41A1C"),
alpha: Optional[int] = 1,
dot_size: Optional[int] = 15,
text_genes: Optional[list] = None,
@@ -225,7 +225,7 @@ def marker_genes_volcano(
:param group_name: the group name.
:param res_key: the result key of marker gene.
:param hue_order: the classification method.
- :param colors: the color set.
+ :param palette: the color theme.
:param alpha: the opacity.
:param dot_size: the dot size.
:param text_genes: show gene names.
@@ -248,7 +248,7 @@ def marker_genes_volcano(
cut_off_pvalue=cut_off_pvalue,
cut_off_logFC=cut_off_logFC,
hue_order=hue_order,
- palette=colors,
+ palette=palette,
alpha=alpha, s=dot_size,
x_label=x_label, y_label=y_label,
vlines=vlines,
@@ -284,7 +284,7 @@ def genes_count(
:param out_dpi: the dpi when the figure is saved.
""" # noqa
import math
- import matplotlib.pyplot as plt
+ # import matplotlib.pyplot as plt
from matplotlib import gridspec
set_xy_empty = False
if x_label == y_label == '' or x_label == y_label == []:
@@ -587,9 +587,9 @@ def violin(
:param rotation_angle: rotation of xtick labels.
:param group_by: the key of the observation grouping to consider.
:param multi_panel: Display keys in multiple panels also when groupby is not None.
- :param scale: The method used to scale the width of each violin. If ‘width’ (the default), each violin will
- have the same width. If ‘area’, each violin will have the same area.
- If ‘count’, a violin’s width corresponds to the number of observations.
+ :param scale: The method used to scale the width of each violin. If 'width' (the default), each violin will
+ have the same width. If 'area', each violin will have the same area.
+ If 'count', a violin's width corresponds to the number of observations.
:param ax: a matplotlib axes object. only works if plotting a single component.
:param order: Order in which to show the categories.
:param use_raw: Whether to use raw attribute of data. Defaults to True if .raw is present.
@@ -668,7 +668,7 @@ def batches_umap(
y_label: Optional[str] = 'umap2',
bfig_title: Optional[str] = 'all batches',
dot_size: Optional[int] = 1,
- colors: Optional[Union[str, list]] = 'stereo_30',
+ palette: Optional[Union[str, list]] = 'stereo_30',
width: Optional[int] = None,
height: Optional[int] = None
):
@@ -681,7 +681,7 @@ def batches_umap(
:param y_label: the y label.
:param bfig_title: the big figure title.
:param dot_size: the dot size.
- :param colors: the color list.
+ :param palette: the color list.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
@@ -707,7 +707,7 @@ def batches_umap(
umap_res['batch'] = self.data.cells.batch.astype(np.uint16)
batch_number_unique = np.unique(umap_res['batch'])
batch_count = len(batch_number_unique)
- cmap = stereo_conf.get_colors(colors, batch_count)
+ cmap = stereo_conf.get_colors(palette, batch_count)
fig_all = umap_res.hvplot.scatter(
x='x', y='y', c='batch', cmap=cmap, cnorm='eq_hist',
).opts(
@@ -769,7 +769,6 @@ def umap(
x_label: Optional[Union[str, list]] = 'umap1',
y_label: Optional[Union[str, list]] = 'umap2',
dot_size: Optional[int] = None,
- colors: Optional[Union[str, list]] = 'stereo',
width: Optional[int] = None,
height: Optional[int] = None,
palette: Optional[int] = None,
@@ -787,7 +786,6 @@ def umap(
:param x_label: the x label.
:param y_label: the y label.
:param dot_size: the dot size.
- :param colors: the color list.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
:param palette: color theme.
@@ -798,13 +796,15 @@ def umap(
""" # noqa
res = self.check_res_key(res_key)
+ if palette is None:
+ palette = 'stereo_30' if cluster_key else 'stereo'
if cluster_key:
cluster_res = self.check_res_key(cluster_key)
n = len(set(cluster_res['group']))
if title is None:
title = cluster_key
- if not palette:
- palette = stereo_conf.get_colors('stereo_30' if colors == 'stereo' else colors, n)
+ # if not palette:
+ # palette = stereo_conf.get_colors('stereo_30' if colors == 'stereo' else colors, n)
return base_scatter(
res.values[:, 0],
res.values[:, 1],
@@ -828,7 +828,7 @@ def umap(
res.values[:, 0],
res.values[:, 1],
hue=self.data.sub_exp_matrix_by_name(gene_name=gene_names).T,
- palette=colors,
+ palette=palette,
title=gene_names if title is None else title,
x_label=[x_label for i in range(len(gene_names))],
y_label=[y_label for i in range(len(gene_names))],
@@ -1298,28 +1298,38 @@ def scenic_clustermap(
@reorganize_coordinate
def cells_plotting(
self,
- color_by: Literal['total_count', 'n_genes_by_counts', 'gene', 'cluster'] = 'total_count',
+ color_by: Literal['total_counts', 'n_genes_by_counts', 'gene', 'cluster'] = 'total_counts',
color_key: Optional[str] = None,
bgcolor: Optional[str] = '#2F2F4F',
+ palette: Optional[Union[str, list, dict]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
fg_alpha: Optional[float] = 0.5,
base_image: Optional[str] = None,
- base_im_to_gray: bool = False
+ base_im_to_gray: bool = False,
+ use_raw: bool = True,
+ show: bool = True
):
"""Plot the cells.
- :param color_by: spcify the way of coloring, default to 'total_count'.
+ :param color_by: spcify the way of coloring, default to 'total_counts'.
if set to 'gene', you need to specify a gene name by `color_key`.
if set to 'cluster', you need to specify the key to get cluster result by `color_key`.
- :param color_key: the key to get the data to color the plot, it is ignored when the `color_by` is set to 'total_count' or 'n_genes_by_counts'.
+ :param color_key: the key to get the data to color the plot, it is ignored when the `color_by` is set to 'total_counts' or 'n_genes_by_counts'.
:param bgcolor: set background color.
+ :param palette: color theme,
+ when `color_by` is 'cluster', it can be a palette name, a list of colors or a dictionary of colors whose keys is the cluster name,
+ when other `color_by` is set, it only can be a string of palette name.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
- :param fg_alpha: the alpha of foreground image, between 0 and 1, defaults to 0.5
+ :param fg_alpha: the transparency of foreground image, between 0 and 1, defaults to 0.5
this is the colored image of the cells.
:param base_image: the path of the ssdna image after calibration, defaults to None
it will be located behide the image of the cells.
+ :param base_im_to_gray: whether to convert the base image to gray scale if base image is RGB/RGBA image.
+ :param use_raw: whether to use raw data, defaults to True if .raw is present.
+ :param show: show the figure directly or get the figure object, defaults to True.
+ If set to False, you need to call the `show` method of the figure object to show the figure.
:param reorganize_coordinate: if the data is merged from several slices, whether to reorganize the coordinates of the obs(cells),
if set it to a number, like 2, the coordinates will be reorganized to 2 columns on coordinate system as below:
---------------
@@ -1331,7 +1341,28 @@ def cells_plotting(
if set it to `False`, the coordinates will not be changed.
:param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates.
:param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates.
- :return: Cells distribution figure.
+
+ .. note::
+
+ Exporting
+ ------------------
+
+ This plot can be exported as PNG and SVG, then converted to PDF.
+
+ You need the following necessary dependencies to support exporting:
+
+ conda install -c conda-forge selenium firefox geckodriver cairosvg
+
+ On Linux systems, you may need to install some additional libraries to support the above dependencies,
+ for example, on Ubuntu, you need to install the following libraries:
+
+ sudo apt-get install libgtk-3-dev libasound2-dev
+
+ On other linux systems, you may need to install the corresponding libraries according to the error message.
+
+ There are two ways to export the plot, one is to manupulate on browser,
+ another is to call the method `save_plot `_ of this figure object.
+
""" # noqa
from .plot_cells import PlotCells
if color_by in ('cluster', 'gene'):
@@ -1343,13 +1374,17 @@ def cells_plotting(
color_key=color_key,
# cluster_res_key=cluster_res_key,
bgcolor=bgcolor,
+ palette=palette,
width=width,
height=height,
fg_alpha=fg_alpha,
base_image=base_image,
- base_im_to_gray=base_im_to_gray
+ base_im_to_gray=base_im_to_gray,
+ use_raw=use_raw
)
- return pc.show()
+ if show:
+ return pc.show()
+ return pc
@download
def correlation_heatmap(
diff --git a/stereo/preprocess/qc.py b/stereo/preprocess/qc.py
index d61a2ae2..1b3ff3d0 100644
--- a/stereo/preprocess/qc.py
+++ b/stereo/preprocess/qc.py
@@ -27,7 +27,8 @@ def cal_cells_indicators(data):
exp_matrix = data.exp_matrix
data.cells.total_counts = cal_total_counts(exp_matrix)
data.cells.n_genes_by_counts = cal_n_genes_by_counts(exp_matrix)
- data.cells.pct_counts_mt = cal_pct_counts_mt(data)
+ # data.cells.pct_counts_mt = cal_pct_counts_mt(data)
+ data.cells.pct_counts_mt = cal_pct_counts_mt(exp_matrix, data.gene_names)
return data
@@ -35,7 +36,8 @@ def cal_genes_indicators(data):
exp_matrix = data.exp_matrix
data.genes.n_cells = cal_n_cells(exp_matrix)
data.genes.n_counts = cal_per_gene_counts(exp_matrix)
- data.genes.mean_umi = cal_gene_mean_umi(data)
+ # data.genes.mean_umi = cal_gene_mean_umi(data)
+ data.genes.mean_umi = cal_gene_mean_umi(exp_matrix)
return data
@@ -79,9 +81,20 @@ def cal_n_cells(exp_matrix):
return exp_matrix.getnnz(axis=0) if issparse(exp_matrix) else np.count_nonzero(exp_matrix, axis=0)
-def cal_gene_mean_umi(data):
+# def cal_gene_mean_umi(data):
+# old_settings = np.seterr(divide='ignore', invalid='ignore')
+# gene_mean_umi = data.genes.n_counts / data.genes.n_cells
+# flag = np.isnan(gene_mean_umi) | np.isinf(gene_mean_umi)
+# gene_mean_umi[flag] = 0
+# np.seterr(**old_settings)
+# return gene_mean_umi
+
+def cal_gene_mean_umi(exp_matrix):
old_settings = np.seterr(divide='ignore', invalid='ignore')
- gene_mean_umi = data.genes.n_counts / data.genes.n_cells
+ # gene_mean_umi = data.genes.n_counts / data.genes.n_cells
+ n_counts = cal_per_gene_counts(exp_matrix)
+ n_cells = cal_n_cells(exp_matrix)
+ gene_mean_umi = n_counts / n_cells
flag = np.isnan(gene_mean_umi) | np.isinf(gene_mean_umi)
gene_mean_umi[flag] = 0
np.seterr(**old_settings)
@@ -92,13 +105,24 @@ def cal_n_genes_by_counts(exp_matrix):
return exp_matrix.getnnz(axis=1) if issparse(exp_matrix) else np.count_nonzero(exp_matrix, axis=1)
-def cal_pct_counts_mt(data):
+# def cal_pct_counts_mt(data):
+# old_settings = np.seterr(divide='ignore', invalid='ignore')
+# if data.cells.total_counts is None:
+# data.cells.total_counts = cal_total_counts(data.exp_matrix)
+# mt_index = np.char.startswith(np.char.lower(data.gene_names), prefix='mt-')
+# mt_count = np.array(data.exp_matrix[:, mt_index].sum(1)).reshape(-1)
+# pct_counts_mt = mt_count / data.cells.total_counts * 100
+# flag = np.isnan(pct_counts_mt) | np.isinf(pct_counts_mt)
+# pct_counts_mt[flag] = 0
+# np.seterr(**old_settings)
+# return pct_counts_mt
+
+def cal_pct_counts_mt(exp_matrix, gene_names):
old_settings = np.seterr(divide='ignore', invalid='ignore')
- if data.cells.total_counts is None:
- data.cells.total_counts = cal_total_counts(data.exp_matrix)
- mt_index = np.char.startswith(np.char.lower(data.gene_names), prefix='mt-')
- mt_count = np.array(data.exp_matrix[:, mt_index].sum(1)).reshape(-1)
- pct_counts_mt = mt_count / data.cells.total_counts * 100
+ total_counts = cal_total_counts(exp_matrix)
+ mt_index = np.char.startswith(np.char.lower(gene_names), prefix='mt-')
+ mt_count = np.array(exp_matrix[:, mt_index].sum(1)).reshape(-1)
+ pct_counts_mt = mt_count / total_counts * 100
flag = np.isnan(pct_counts_mt) | np.isinf(pct_counts_mt)
pct_counts_mt[flag] = 0
np.seterr(**old_settings)