diff --git a/.gitignore b/.gitignore index 0f635c3c..b94ad25a 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,6 @@ example_data # vscode *.code-workspace + +# dash widget +file_system_backend \ No newline at end of file diff --git a/element_calcium_imaging/plotting/draw_rois.py b/element_calcium_imaging/plotting/draw_rois.py new file mode 100644 index 00000000..0b8481d4 --- /dev/null +++ b/element_calcium_imaging/plotting/draw_rois.py @@ -0,0 +1,228 @@ +import yaml +import datajoint as dj +import numpy as np +import plotly.express as px +import plotly.graph_objects as go +from dash import no_update +from dash_extensions.enrich import ( + DashProxy, + Input, + Output, + State, + html, + dcc, + Serverside, + ServersideOutputTransform, +) + +from .utilities import * + + +logger = dj.logger + + +def draw_rois(db_prefix: str): + scan = dj.create_virtual_module("scan", f"{db_prefix}scan") + imaging = dj.create_virtual_module("imaging", f"{db_prefix}imaging") + all_keys = (imaging.MotionCorrection).fetch("KEY") + + colors = {"background": "#111111", "text": "#00a0df"} + + app = DashProxy(transforms=[ServersideOutputTransform()]) + app.layout = html.Div( + [ + html.H2("Draw ROIs", style={"color": colors["text"]}), + html.Label( + "Select data key from dropdown", style={"color": colors["text"]} + ), + dcc.Dropdown( + id="toplevel-dropdown", options=[str(key) for key in all_keys] + ), + html.Br(), + html.Div( + [ + html.Button( + "Load Image", + id="load-image-button", + style={"margin-right": "20px"}, + ), + dcc.RadioItems( + id="image-type-radio", + options=[ + {"label": "Average Image", "value": "average_image"}, + { + "label": "Max Projection Image", + "value": "max_projection_image", + }, + ], + value="average_image", + labelStyle={"display": "inline-block", "margin-right": "10px"}, + style={"display": "inline-block", "color": colors["text"]}, + ), + html.Div( + [ + html.Button("Submit Curated Masks", id="submit-button"), + ], + style={ + "textAlign": "right", + "flex": "1", + "display": "inline-block", + }, + ), + ], + style={ + "display": "flex", + "justify-content": "flex-start", + "align-items": "center", + }, + ), + html.Br(), + html.Br(), + html.Div( + [ + dcc.Graph( + id="avg-image", + config={ + "modeBarButtonsToAdd": [ + "drawclosedpath", + "drawrect", + "drawcircle", + "drawline", + "eraseshape", + ], + }, + style={"width": "100%", "height": "100%"}, + ) + ], + style={ + "display": "flex", + "justify-content": "center", + "align-items": "center", + "padding": "0.0", + "margin": "auto", + }, + ), + html.Pre(id="annotations"), + html.Div(id="button-output"), + dcc.Store(id="store-key"), + dcc.Store(id="store-mask"), + dcc.Store(id="store-movie"), + html.Div(id="submit-output"), + ] + ) + + @app.callback( + Output("store-key", "value"), + Input("toplevel-dropdown", "value"), + ) + def store_key(value): + if value is not None: + return Serverside(value) + else: + return no_update + + @app.callback( + Output("avg-image", "figure"), + Output("store-movie", "average_images"), + State("store-key", "value"), + Input("load-image-button", "n_clicks"), + Input("image-type-radio", "value"), + prevent_initial_call=True, + ) + def create_figure(value, render_n_clicks, image_type): + if render_n_clicks is not None: + if image_type == "average_image": + summary_images = ( + imaging.MotionCorrection.Summary & yaml.safe_load(value) + ).fetch("average_image") + else: + summary_images = ( + imaging.MotionCorrection.Summary & yaml.safe_load(value) + ).fetch("max_proj_image") + average_images = [image.astype("float") for image in summary_images] + roi_contours = get_contours(yaml.safe_load(value), db_prefix) + logger.info("Generating figure.") + fig = px.imshow( + np.asarray(average_images), + animation_frame=0, + binary_string=True, + labels=dict(animation_frame="plane"), + ) + for contour in roi_contours: + # Note: contour[:, 1] are x-coordinates, contour[:, 0] are y-coordinates + fig.add_trace( + go.Scatter( + x=contour[:, 1], # Plotly uses x, y order for coordinates + y=contour[:, 0], + mode="lines", # Display as lines (not markers) + line=dict(color="white", width=0.5), # Set line color and width + showlegend=False, # Do not show legend for each contour + ) + ) + fig.update_layout( + dragmode="drawrect", + autosize=True, + height=550, + newshape=dict(opacity=0.6, fillcolor="#00a0df"), + plot_bgcolor=colors["background"], + paper_bgcolor=colors["background"], + font_color=colors["text"], + ) + fig.update_annotations(bgcolor="#00a0df") + else: + return no_update + return fig, Serverside(average_images) + + @app.callback( + Output("store-mask", "annotation_list"), + Input("avg-image", "relayoutData"), + prevent_initial_call=True, + ) + def on_relayout(relayout_data): + if not relayout_data: + return no_update + else: + if "shapes" in relayout_data: + global shape_type + try: + shape_type = relayout_data["shapes"][-1]["type"] + return Serverside(relayout_data) + except IndexError: + return no_update + elif any(["shapes" in key for key in relayout_data]): + return Serverside(relayout_data) + + @app.callback( + Output("submit-output", "children"), + Input("submit-button", "n_clicks"), + State("store-mask", "annotation_list"), + State("store-key", "value"), + ) + def submit_annotations(n_clicks, annotation_list, value): + x_mask_li = [] + y_mask_li = [] + if n_clicks is not None: + if annotation_list: + if "shapes" in annotation_list: + logger.info("Creating Masks.") + shapes = [d["type"] for d in annotation_list["shapes"]] + for shape, annotation in zip(shapes, annotation_list["shapes"]): + mask = create_mask(annotation, shape) + y_mask_li.append(mask[0]) + x_mask_li.append(mask[1]) + print("Masks created") + insert_into_database( + scan, imaging, yaml.safe_load(value), x_mask_li, y_mask_li + ) + else: + logger.warn( + "Incorrect annotation list format. This is a known bug. Please draw a line anywhere on the image and click `Submit Curated Masks`. It will be ignored in the final submission but will format the list correctly." + ) + return no_update + else: + logger.warn("No annotations to submit.") + return no_update + else: + return no_update + + return app diff --git a/element_calcium_imaging/plotting/utilities.py b/element_calcium_imaging/plotting/utilities.py new file mode 100644 index 00000000..f1ab8fbb --- /dev/null +++ b/element_calcium_imaging/plotting/utilities.py @@ -0,0 +1,219 @@ +import pathlib +import datajoint as dj +import numpy as np +from scipy import ndimage +from skimage import draw, measure +from element_interface.utils import find_full_path + + +logger = dj.logger + + +def get_imaging_root_data_dir(): + """Retrieve imaging root data directory.""" + imaging_root_dirs = dj.config.get("custom", {}).get("imaging_root_data_dir", None) + if not imaging_root_dirs: + return None + elif isinstance(imaging_root_dirs, (str, pathlib.Path)): + return [imaging_root_dirs] + elif isinstance(imaging_root_dirs, list): + return imaging_root_dirs + else: + raise TypeError("`imaging_root_data_dir` must be a string, pathlib, or list") + + +def path_to_indices(path): + """From SVG path to numpy array of coordinates, each row being a (row, col) point""" + indices_str = [ + el.replace("M", "").replace("Z", "").split(",") for el in path.split("L") + ] + return np.rint(np.array(indices_str, dtype=float)).astype(int) + + +def path_to_mask(path, shape): + """From SVG path to a boolean array where all pixels enclosed by the path + are True, and the other pixels are False. + """ + cols, rows = path_to_indices(path).T + rr, cc = draw.polygon(rows, cols) + mask = np.zeros(shape, dtype=bool) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + return mask + + +def create_ellipse_mask(vertices, image_shape): + """ + Create a mask for an ellipse given its vertices. + + :param vertices: Tuple of (x0, y0, x1, y1) representing the bounding box of the ellipse. + :param image_shape: Shape of the image (height, width) to create a mask for. + :return: Binary mask with the ellipse. + """ + x0, x1, y0, y1 = vertices + center = ((x0 + x1) / 2, (y0 + y1) / 2) + axis_lengths = (abs(x1 - x0) / 2, abs(y1 - y0) / 2) + + rr, cc = draw.ellipse( + center[1], center[0], axis_lengths[1], axis_lengths[0], shape=image_shape + ) + mask = np.zeros(image_shape, dtype=np.bool_) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + + return mask + + +def create_rectangle_mask(vertices, image_shape): + """ + Create a mask for a rectangle given its vertices. + + :param vertices: Tuple of (x0, y0, x1, y1) representing the top-left and bottom-right corners of the rectangle. + :param image_shape: Shape of the image (height, width) to create a mask for. + :return: Binary mask with the rectangle. + """ + x0, x1, y0, y1 = vertices + rr, cc = draw.rectangle(start=(y0, x0), end=(y1, x1), shape=image_shape) + mask = np.zeros(image_shape, dtype=np.bool_) + mask[rr, cc] = True + mask = ndimage.binary_fill_holes(mask) + + return mask + + +def create_mask(coordinates, shape_type): + if shape_type == "path": + try: + mask = np.asarray(path_to_mask(coordinates["path"], (512, 512))).nonzero() + except KeyError: + for key, info in coordinates.items(): + mask = np.asarray(path_to_mask(info, (512, 512))).nonzero() + + elif shape_type == "circle": + try: + mask = np.asarray( + create_ellipse_mask( + [ + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ], + (512, 512), + ) + ).nonzero() + except KeyError: + xy_coordinates = np.asarray( + [item for item in coordinates.values()], dtype="int" + ) + mask = np.asarray(create_ellipse_mask(xy_coordinates, (512, 512))).nonzero() + elif shape_type == "rect": + try: + mask = np.asarray( + create_rectangle_mask( + [ + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ], + (512, 512), + ) + ).nonzero() + except KeyError: + xy_coordinates = np.asarray( + [item for item in coordinates.values()], dtype="int" + ) + mask = np.asarray( + create_rectangle_mask(xy_coordinates, (512, 512)) + ).nonzero() + elif shape_type == "line": + try: + mask = np.array( + ( + int(coordinates["x0"]), + int(coordinates["x1"]), + int(coordinates["y0"]), + int(coordinates["y1"]), + ) + ) + except KeyError: + mask = np.asarray([item for item in coordinates.values()], dtype="int") + return mask + + +def get_contours(image_key, prefix): + scan = dj.create_virtual_module("scan", f"{prefix}scan") + imaging = dj.create_virtual_module("imaging", f"{prefix}imaging") + yshape, xshape = (scan.ScanInfo.Field & image_key).fetch1("px_height", "px_width") + mask_xpix, mask_ypix = (imaging.Segmentation.Mask & image_key).fetch( + "mask_xpix", "mask_ypix" + ) + mask_image = np.zeros((yshape, xshape), dtype=bool) + for xpix, ypix in zip(mask_xpix, mask_ypix): + mask_image[ypix, xpix] = True + contours = measure.find_contours(mask_image.astype(float), 0.5) + return contours + + +def load_imaging_data_for_session(scan, key): + image_files = (scan.ScanInfo.ScanFile & key).fetch("file_path") + image_files = [ + find_full_path(get_imaging_root_data_dir(), image_file) + for image_file in image_files + ] + acq_software = (scan.Scan & key).fetch1("acq_software") + if acq_software == "ScanImage": + import tifffile + + imaging_data = tifffile.imread(image_files[0]) + elif acq_software == "NIS": + import nd2 + + imaging_data = nd2.imread(image_files[0]) + else: + raise ValueError( + f"Support for images with acquisition software: {acq_software} is not yet implemented into the widget." + ) + return imaging_data + + +def insert_into_database(scan_module, imaging_module, session_key, x_masks, y_masks): + images = load_imaging_data_for_session(scan_module, session_key) + mask_id = (imaging_module.Segmentation.Mask & session_key).fetch( + "mask", order_by="mask desc", limit=1 + ) + logger.info(f"Inserting {len(x_masks)} masks into the database.") + imaging_module.Segmentation.Mask.insert( + [ + dict( + **session_key, + mask=mask_id + mask_num, + segmentation_channel=1, + mask_npix=y_mask.shape[0], + mask_center_x=int(sum(x_mask) / x_mask.shape[0]), + mask_center_y=int(sum(y_mask) / y_mask.shape[0]), + mask_center_z=0, + mask_xpix=x_mask, + mask_ypix=y_mask, + mask_zpix=0, + mask_weights=np.ones_like(y_mask), + ) + for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + ], + allow_direct_insert=True, + ) + logger.info(f"Inserting {len(x_masks)} traces into the database.") + imaging_module.Fluorescence.Trace.insert( + [ + dict( + **session_key, + mask=mask_id + mask_num, + fluo_channel=1, + fluorescence=images[:, y_mask, x_mask].mean(axis=1), + ) + for mask_num, (x_mask, y_mask) in enumerate(zip(x_masks, y_masks)) + ], + allow_direct_insert=True, + ) + logger.info("Inserts complete.") diff --git a/setup.py b/setup.py index 2586295c..4dfed13a 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,8 @@ "ipykernel>=6.0.1", "ipywidgets", "plotly", + "dash-extensions", + "scikit-image", "element-interface @ git+https://github.com/datajoint/element-interface.git", ], extras_require={