
[GOAL] Demo viewing of large multi-chan timeseries data with multi-time-resolution generator and dynamic accessor #87

Closed
2 tasks done
droumis opened this issue Jan 5, 2024 · 49 comments


@droumis
Collaborator

droumis commented Jan 5, 2024

Problem:

On their own, our current methods like Datashader and downsampling are insufficient for data that cannot be fully loaded into memory.

Description/Solution/Goals:

This project aims to enable effective processing and visualization of biological datasets that exceed available memory limits. The task is to develop a proof of concept for an xarray-datatree-based multi-resolution generator and dynamic accessor. This involves generating and storing incrementally downsampled versions of a large dataset, and then accessing the appropriate resolution copy based on viewport and screen parameters. We want to leverage existing work and standards as much as possible, aligning with the geo and bio communities.
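
Choosing "the appropriate resolution copy" can be framed as a budget check: estimate how many samples each pyramid level would draw for the current viewport, then use the finest level that stays under a limit. The sketch below is a minimal illustration under stated assumptions. It assumes levels are ordered from finest (index 0) to coarsest. It treats max_points as a hypothetical, configurable budget, for example a small multiple of the plot width in pixels. pick_level is an illustrative helper, not a HoloViews or datatree API.

```python
import numpy as np


def pick_level(level_sizes, full_res_in_view, max_points):
    """Return the index of the finest pyramid level whose visible sample
    count fits within max_points (hypothetical helper, not a library API).

    level_sizes: total samples per level, ordered finest (0) to coarsest.
    full_res_in_view: full-resolution samples inside the current viewport.
    max_points: configurable budget, e.g. a multiple of the plot width in px.
    """
    level_sizes = np.asarray(level_sizes, dtype=float)
    # Fraction of the full timeseries currently visible in the viewport.
    visible_fraction = full_res_in_view / level_sizes[0]
    # Approximate number of samples each level would draw for that viewport.
    points_in_view = level_sizes * visible_fraction
    # Points shrink as levels get coarser, so the first level that fits is
    # the finest one; fall back to the coarsest level if none fit.
    fits = np.flatnonzero(points_in_view <= max_points)
    return int(fits[0]) if fits.size else len(level_sizes) - 1


sizes = [1_000_000, 500_000, 250_000, 125_000]
full_view = pick_level(sizes, full_res_in_view=1_000_000, max_points=200_000)
zoomed_in = pick_level(sizes, full_res_in_view=100_000, max_points=200_000)
```

With these made-up sizes and a 200,000-point budget, viewing the whole series selects the coarsest level (3). Zooming in to 100,000 full-resolution samples selects level 0.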

Potential Methods and Tools to Leverage:

Tasks:

  • 1. Research xarray-datatree storage conventions in zarr, summarize the options here, and then write a notebook that takes the ephys data and uses the downsample1d operation (or uses tsdownsample directly) to generate a hierarchical tree of downsampled versions of the data.
  • 2. Write a notebook that uses xarray-datatree and a DynamicMap to load data based on zoom level.
  3. Implement a data interface in HoloViews that wraps xarray-datatree and automatically loads the appropriate subset of data given a configurable max data size.
  4. Repeat the relevant steps above for the microscopy use case and dataset.

Use-Cases, Starter Viz Code, and Datasets:

Stacked timeseries:

Summary

  • A stacked timeseries plot is commonly used for synchronized examination of time-aligned, amplitude/unit-diverse timeseries. A useful stacked timeseries plot might look like this:
Code
from scipy.stats import zscore
import h5py

import holoviews as hv; hv.extension('bokeh')
from holoviews.plotting.links import RangeToolLink
from holoviews.operation.datashader import rasterize
from bokeh.models import HoverTool

filename = 'ephys_sim_neuropixels_10s_384ch.h5'  # smaller simulated dataset listed under Data below
f = h5py.File(filename, "r")

n_sample_chans = 40
n_sample_times = 25000 # sampling frequency is 25 kHz
clim_mul = 2

# main plot
hover = HoverTool(tooltips=[
    ("Channel", "@channel"),
    ("Time", "$x s"),
    ("Amplitude", "$y µV")])

time = f['timestamps'][:n_sample_times]
data = f['recordings'][:n_sample_times,:n_sample_chans].T

f.close()

channels = [f'ch{i}' for i in range(n_sample_chans)]
channels = channels[:n_sample_chans]

channel_curves = []
for i, channel in enumerate(channels):
    ds = hv.Dataset((time, data[i,:], channel), ["Time", "Amplitude", "channel"])
    curve = hv.Curve(ds, "Time", ["Amplitude", "channel"], label=f'{channel}')
    curve.opts(color="black", line_width=1, subcoordinate_y=True, subcoordinate_scale=3, tools=[hover])
    channel_curves.append(curve)

curves = hv.Overlay(channel_curves, kdims="Channel")

curves = curves.opts(
    xlabel="Time (s)", ylabel="Channel", show_legend=False,
    padding=0, aspect=1.5, responsive=True, shared_axes=False, framewise=False)

# minimap
y_positions = range(len(channels))
yticks = [(i, ich) for i, ich in enumerate(channels)]
z_data = zscore(data, axis=1)

minimap = rasterize(hv.Image((time, y_positions, z_data), ["Time (s)", "Channel"], "Amplitude (z-score)"))
minimap = minimap.opts(
    cmap="RdBu_r", colorbar=False, xlabel='', yticks=[yticks[0], yticks[-1]], toolbar='disable',
    height=120, responsive=True, clim=(-z_data.std()*clim_mul, z_data.std()*clim_mul))

RangeToolLink(minimap, curves, axes=["x", "y"],
              boundsx=(.1, .3),
              boundsy=(10, 30))

(curves + minimap).cols(1)
  • The main interaction is zooming and panning through time (x) or channels (y). A primary goal of this multi-res initiative is to make this interaction responsive, regardless of the dataset size.
  • In addition to a main plot of the actual timeseries (dims: time, source/channel), it helps to have a minimap/RangeToolLink image that gives an impression of the whole dataset and lets you move the main plot's viewport to regions of interest. For example, see the minimap at the bottom of the image above: it is a rasterized version of the entire dataset (at least the chunk I chose for this demo) with the same x and y dims as the main plot. This particular minimap is z-scored per channel/source to normalize amplitude across channels and make patterns easier to spot (although that isn't necessary for this demo dataset, which has been bandpass filtered).
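
Per-channel z-scoring rescales each row to zero mean and unit variance before rasterizing, so a quiet channel and a loud one share a color scale. The numpy sketch below performs the same computation as scipy.stats.zscore(data, axis=1) with its default ddof=0. It runs on made-up data and then derives the symmetric color limits used in the minimap code (clim_mul = 2).

```python
import numpy as np


def zscore_rows(data):
    """Z-score each channel (row) independently; same result as
    scipy.stats.zscore(data, axis=1) with its default ddof=0."""
    mean = data.mean(axis=1, keepdims=True)
    std = data.std(axis=1, keepdims=True)
    return (data - mean) / std


# Made-up data: 3 channels whose raw amplitudes differ by orders of magnitude.
rng = np.random.default_rng(0)
raw = rng.normal(size=(3, 1000)) * np.array([[1.0], [100.0], [10_000.0]])
z = zscore_rows(raw)

# Symmetric color limits, as in the minimap opts above.
clim_mul = 2
clim = (-z.std() * clim_mul, z.std() * clim_mul)
```

After scaling, every channel spans a comparable range. The overall standard deviation of z is then 1, so the limits come out close to (-2, 2).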

Data

  • The datasets below are simulated multielectrode electrophysiological (ephys) data, saved to .h5 (a common underlying format for ephys data). They were created with this nb.
    • Larger Simulated Ephys data (5,000,000 time samples (200s), 384 channels) - 15 GB:
      • datasets.holoviz.org/ephys_sim/v1/ephys_sim_neuropixels_200s_384ch.h5
    • Smaller Simulated Ephys data (250,000 time samples (10s), 384 channels) - 3 GB:
      • datasets.holoviz.org/ephys_sim/v1/ephys_sim_neuropixels_10s_384ch.h5

Note: I recommend working through the notebook Ian created on accessing ephys HDF5 datasets in xarray via Kerchunk and Zarr. I can imagine an approach to multiresolution access that just uses Kerchunk references instead of downsampled data copies, although I'm not sure how that would work with xarray-datatree; maybe it would have to be either Kerchunk or xarray-datatree, but not both. Maybe we could consult Martin.

Miniscope Image Stack: UPDATE: solved without needing multi-res handling

Summary

  • A miniscope image stack typically has a modest height and width resolution but a deep time/frame dimension. A useful miniscope image stack viewer might look like this:
Code
import xarray as xr
import panel as pn; pn.extension()
import holoviews as hv; hv.extension('bokeh')
import hvplot.xarray

DATA_ARRAY = '1000frames'

DATA_PATH = f"<miniscope_sim_{DATA_ARRAY}.zarr>"

ldataset = xr.open_dataset(DATA_PATH, engine='zarr', chunks='auto')

data = ldataset[DATA_ARRAY]

# data.hvplot.image(groupby="frame", cmap="Viridis", height=400, width=400, colorbar=False)

FRAMES_PER_SECOND = 30
FRAMES = data.coords["frame"].values

# Create a video player widget
video_player = pn.widgets.Player(
    length=len(data.coords["frame"]),
    interval=1000 // FRAMES_PER_SECOND,  # ms
    value=int(FRAMES.min()),
    max_width=400,
    max_height=90,
    loop_policy="loop",
    sizing_mode="stretch_width",
)

# Create the main plot
main_plot = data.hvplot.image(
    groupby="frame",
    cmap="Viridis",
    frame_height=400,
    frame_width=400,
    colorbar=False,
    widgets={"frame": video_player},
)

# frame indicator lines on side plots
line_opts = dict(color='red', alpha=.6, line_width=3)
dmap_hline = hv.DynamicMap(pn.bind(lambda value: hv.HLine(value), video_player)).opts(**line_opts)
dmap_vline = hv.DynamicMap(pn.bind(lambda value: hv.VLine(value), video_player)).opts(**line_opts)

# height side view
right_plot = data.mean(['width']).hvplot.image(x='frame',
    cmap="Viridis",
    frame_height=400,
    frame_width=200,
    colorbar=False,
    rasterize=True,
    title='_', # TODO: Fix this. See https://github.com/bokeh/bokeh/issues/13225#issuecomment-1611172355
) * dmap_vline

# width side view
bottom_plot = data.mean(['height']).hvplot.image(y='frame',
    cmap="Viridis",
    frame_height=200,
    frame_width=400,
    colorbar=False,
    rasterize=True,
) * dmap_hline

video_player.margin = (20, 20, 20, 70) # center widget over main

sim_app = pn.Column(
    video_player,
    pn.Row(main_plot[0], right_plot),
    bottom_plot)

sim_app
  • The main interaction is scrubbing or playing through the time/frames. A primary goal of this multi-res initiative is to make this scrubbing/playing responsive, regardless of the dataset size.
  • In addition to a main plot of a single time/frame (dims: height, width), it is helpful to see 2D side views of the image-stack cube in which either the width or the height dimension is aggregated. For example, in the plot on the right of the image above, the width dimension is max-aggregated (the code above uses the mean instead), showing how height values progress over times/frames. Another primary goal of this multi-res initiative is to be able to render and display these side plots regardless of the dataset size.
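
Each side view is a reduction along one spatial axis, and every reduction touches every frame. That is why these plots, unlike the single-frame view, scale with the depth of the stack. The numpy sketch below uses a made-up stack and assumes the axes are ordered (frame, height, width). The mean matches what the viewer code does, and max gives the maximum-intensity variant.

```python
import numpy as np

# Made-up stack with dims ordered (frame, height, width); a real stack would
# be much deeper in frames.
rng = np.random.default_rng(1)
stack = rng.random((100, 64, 48))

# Height side view: collapse width (axis 2) -> (frame, height).
height_view_mean = stack.mean(axis=2)  # like data.mean(['width'])
height_view_max = stack.max(axis=2)    # maximum-intensity variant

# Width side view: collapse height (axis 1) -> (frame, width).
width_view_mean = stack.mean(axis=1)   # like data.mean(['height'])
```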

Data

  • The datasets below are simulated miniscope data, chunked in the time/frame dimension and saved to Zarr via xarray (the xarray-augmented Zarr format). They were created with this script, which runs code from here.
    • Larger Simulated Miniscope data (512 height, 512 width, 10,000 frames) - 24 GB:
      • datasets.holoviz.org/sim_miniscope/v1/miniscope_sim_10000frames.zarr
    • Smaller Simulated Miniscope data (512 height, 512 width, 1,000 frames) - 2.4 GB:
      • datasets.holoviz.org/sim_miniscope/v1/miniscope_sim_1000frames.zarr

Additional Notes and Resources:

@droumis droumis changed the title [POC] Development of Xarray-Datatree-Based Multi-Resolution Generator and Dynamic Accessor POC Development of Xarray-Datatree-Based Multi-Resolution Generator and Dynamic Accessor Jan 5, 2024
@droumis droumis changed the title POC Development of Xarray-Datatree-Based Multi-Resolution Generator and Dynamic Accessor POC Xarray-Datatree-Based Multi-Resolution Generator and Dynamic Accessor Jan 5, 2024
@droumis droumis moved this to Todo in CZI R5 neuro Jan 9, 2024
@d-v-b

d-v-b commented Jan 10, 2024

a bit stale but maybe relevant: xarray-multiscale

Just to point out, xarray-multiscale looks stale because nobody has asked for anything new, and it works :) Let me know if there's anything you'd like to see added there!

@ahuang11
Collaborator

@philippjfr
Collaborator

philippjfr commented Jan 24, 2024

To kick this off I suggest let's break this into three tasks:

  1. Research xarray-datatree storage conventions in zarr, summarize the options here and then write a notebook that takes the ephys data and uses the downsample1d operation to generate a hierarchical tree of downsampled versions of the data and persists that to zarr using some canonical format.
  2. Write a notebook that uses xarray-datatree and a DynamicMap to load based on zoom level
  3. Implement a data interface in HoloViews that wraps xarray-datatree and loads the appropriate subset of data automatically given a configurable max data size.

I would not use the miniscope data for anything here, since it doesn't actually need the xarray-datatree approach at all: each 512x512 image fits in memory just fine, so as long as frames can be loaded independently and lazily, we can already handle this data. If we do want some data we can't currently handle, we'd need to collect some huge EM imagery where each individual frame has to be downsampled ahead of time to be workable.
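
A back-of-the-envelope check on "fits in memory just fine". It assumes float64 pixels (8 bytes), which is an assumption, because the dtype of the simulated store isn't stated in this issue.

```python
# Assumption: pixels are float64 (8 bytes); the simulated store's actual
# dtype is not stated in this issue.
height, width = 512, 512
bytes_per_value = 8

frame_bytes = height * width * bytes_per_value  # one frame in memory
frame_mib = frame_bytes / 2**20                 # bytes -> MiB
frames_per_gib = 2**30 // frame_bytes           # frames that fit in 1 GiB
```

Under that assumption, one frame is 2 MiB and 512 frames fit in 1 GiB. Lazily loading one frame at a time is therefore cheap, even though the full stack is not.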

For now I'd suggest @ahuang11 start with task 1 and then let's revisit at the next CZI meeting.

@philippjfr
Collaborator

philippjfr commented Jan 24, 2024

Oh, I suppose the part I missed about the microscopy imagery dataset is that you might indeed want a time cross-section, so it might still be useful to resample it along the time dimension. But let's start with the ephys dataset anyway.

@droumis
Collaborator Author

droumis commented Jan 24, 2024

💯 Thanks @philippjfr! I updated the issue description in accordance with your suggested task priority list.

@ahuang11, sound good?

@philippjfr
Collaborator

Note if you don't want to use the downsample1d operation it should be straightforward to just use the tsdownsample library directly.

@ahuang11
Collaborator

ahuang11 commented Jan 24, 2024

That sounds good to me. The only thing is that I can't guarantee I'll have time to get to this by next Tuesday (depends on how successful / straightforward the other task I have is); I'll try though!

@philippjfr
Collaborator

No worries, it's not a super high priority.

@d-v-b

d-v-b commented Jan 24, 2024

If we do want some data we can't currently handle we'd need to collect some huge EM imagery where each individual frame has to be downsampled ahead of time to be workable.

I work with that kind of data! Here's an example: (neuroglancer link)

I have a lot of data like this, and I would love to be able to browse it from a jupyter notebook. In fact, there is (to my knowledge) no python solution for browsing this data in an acceptable way. I'd be super excited to try anything you make on some of our datasets.

@ahuang11
Collaborator

ahuang11 commented Feb 1, 2024

Okay I got around to kicking this off, starting with task 1

First off, the following code is heavily adapted from https://github.com/carbonplan/ndpyramid/blob/main/ndpyramid/core.py

All I did was make it accept a different function (instead of ds.coarsen). Reading the Zarr conventions for multiscale pyramids https://forum.image.sc/t/multiscale-arrays-v0-1/37930, CarbonPlan did a great job with multiscales_template so I copied it 99%. I also submitted an issue carbonplan/ndpyramid#94 to request generalization of the function so I don't have to copy/paste

The only thing I think could be invalid is this requirement: "The paths to the arrays in dataset series MUST be ordered from largest (i.e. highest resolution) to smallest." pyramid_coarsen doesn't sort its factors, so in the example pyramid_coarsen(ds, factors=[16, 8, 4, 3, 2, 1], dims=['lat', 'lon'], boundary='trim'), path "0" ends up being the most coarsened (smallest) level. I think the factors should be reversed and validated in the function (I added factors = sorted(factors)); see carbonplan/ndpyramid#95.
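
The ordering rule can be checked by computing the level sizes before doing any downsampling. The sketch below uses plan_levels, a hypothetical helper that mirrors how pyramid_downsample in the code that follows sizes its levels: factors are sorted ascending, and each level keeps n_samples // (factor + 1) samples. Path "0" therefore always holds the highest-resolution copy.

```python
def plan_levels(n_samples, factors):
    """Hypothetical helper mirroring the level sizes of pyramid_downsample.

    Factors are sorted ascending so the highest-resolution level lands at
    path "0", per the "largest to smallest" rule; each level keeps
    n_samples // (factor + 1) samples (factor 0 keeps everything).
    """
    plan = [
        {"path": str(key), "n_out": n_samples // (factor + 1)}
        for key, factor in enumerate(sorted(factors))
    ]
    # Guard: sizes must be non-increasing from path "0" onward.
    sizes = [level["n_out"] for level in plan]
    if any(a < b for a, b in zip(sizes, sizes[1:])):
        raise ValueError(f"levels are not ordered largest to smallest: {sizes}")
    return plan


plan = plan_levels(1000, [8, 0, 2, 1, 4])
```

Passing the factors out of order still produces a valid plan, because sorting happens first.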

Anyhow, I think this accomplishes task 1 (for 1D data).

I will work on task 2 tomorrow; let me know if you have any concerns before I proceed!

import h5py
import xarray as xr
import datatree as dt
from tsdownsample import MinMaxLTTBDownsampler

def downsample(ds, n_out):
    time_index = MinMaxLTTBDownsampler().downsample(ds["time"], ds["data"], n_out=n_out)
    return ds.isel(time=time_index)


# adapted from https://github.com/carbonplan/ndpyramid/blob/main/ndpyramid/core.py
def multiscales_template(
    *,
    datasets: list = None,
    type: str = "",
    method: str = "",
    version: str = "",
    args: list = None,
    kwargs: dict = None,
):
    if datasets is None:
        datasets = []
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    # https://forum.image.sc/t/multiscale-arrays-v0-1/37930
    return [
        {
            "datasets": datasets,
            "type": type,
            "metadata": {
                "method": method,
                "version": version,
                "args": args,
                "kwargs": kwargs,
            },
        }
    ]


def pyramid_downsample(
    ds: xr.Dataset, *, factors: list[int], **kwargs
) -> dt.DataTree:
    """Create a multiscale pyramid by downsampling a dataset by given factors

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to downsample.
    factors : list[int]
        The downsampling factors.
    kwargs : dict
        Currently unused; only recorded in the multiscales metadata.
    """
    factors = sorted(factors)

    # multiscales spec
    save_kwargs = locals()
    del save_kwargs["ds"]

    attrs = {
        "multiscales": multiscales_template(
            datasets=[{"path": str(i)} for i in range(len(factors))],
            type="pick",
            method="pyramid_downsample",
            version="0.1",
            kwargs=save_kwargs,
        )
    }

    # set up pyramid
    plevels = {}

    # pyramid data
    for key, factor in enumerate(factors):
        factor += 1
        result = downsample(ds, len(ds["data"]) // factor)
        plevels[str(key)] = result

    plevels["/"] = xr.Dataset(attrs=attrs)
    return dt.DataTree.from_dict(plevels)

h5_f = h5py.File("allensdk_cache/session_715093703/session_715093703.nwb", "r")

times = h5_f["acquisition"]["raw_running_wheel_rotation"]["timestamps"]
data = h5_f["acquisition"]["raw_running_wheel_rotation"]["data"]
ts_ds = xr.DataArray(data, coords={"time": times}, dims=["time"], name="data").to_dataset()

ts_dt = pyramid_downsample(ts_ds, factors=[0, 1, 2, 4, 8])
ts_dt.to_zarr("timeseries.zarr", mode="w")

dt.open_datatree("timeseries.zarr", engine="zarr")

@droumis
Collaborator Author

droumis commented Feb 1, 2024

Hey @ahuang11, looking great so far!! Can you test it on the electrophysiology data? I don't imagine it would behave any differently, but the raw_running_wheel_rotation data in your screenshot is a simpler, lower-resolution, behavior-related timeseries than the electrophysiology data. The ephys data is just at a different path in the .nwb file.

It would be so great to see it working for both this real allensdk LFP data as well as the simulated data mentioned in this issue's description (e.g. datasets.holoviz.org/ephys_sim/v1/ephys_sim_neuropixels_200s_384ch.h5).

@ahuang11
Collaborator

ahuang11 commented Feb 1, 2024

Is it running wheel signal voltage?

h5_f.keys()
<KeysViewHDF5 ['acquisition', 'analysis', 'file_create_date', 'general', 'identifier', 'intervals', 'processing', 'session_description', 'session_start_time', 'specifications', 'stimulus', 'timestamps_reference_time', 'units']>

h5_f["acquisition"].keys()
<KeysViewHDF5 ['raw_running_wheel_rotation', 'running_wheel_signal_voltage', 'running_wheel_supply_voltage']>

@droumis
Collaborator Author

droumis commented Feb 1, 2024

Sorry, the probe LFP file is actually linked from within that 'session' file. Ian's nb shows how to grab the probe LFP file data.

# session and session_id come from the AllenSDK setup in Ian's notebook
probe_id = session.probes.index.values[0]
lfp = session.get_lfp(probe_id) # This will download 2 GB of LFP data

The NWB file that we want is now stored locally. We need its filename to read it directly so that we don't have to use the AllenSDK anymore.

import os
import h5py

# local_cache_dir is the AllenSDK cache directory from the same notebook
lfp_nwb_filename = os.path.join(local_cache_dir, f"session_{session_id}", f"probe_{probe_id}_lfp.nwb")
f = h5py.File(lfp_nwb_filename, "r")
lfp_data = f[f"acquisition/probe_{probe_id}_lfp/probe_{probe_id}_lfp_data/data"]

@ahuang11
Collaborator

ahuang11 commented Feb 1, 2024

Thanks! I realized the line that downloads the data was commented out, which is why I couldn't find it: # lfp = session.get_lfp(probe_id) # This will load 2 GB of LFP data

Anyhow, I was able to use the LFP data and generate pyramids for it. I was wondering, though: should MinMaxLTTBDownsampler().downsample be applied across all channels at once, or to each channel individually? I think across all channels, since otherwise the times don't match and the result ends up as a sparse array?


Lastly, a couple issues I couldn't resolve today:

  1. I can't figure out how to add back the time coordinates without returning a 2D array: ts_ds_downsampled["time"] = ts_ds["time"].isel(time=indices.values[0])
  2. If I don't call ts_ds = ts_ds.load() first, it crashes with KeyError: dim1

MRVE questions: https://discourse.pangeo.io/t/return-a-3d-object-alongside-1d-object-in-apply-ufunc/4008

Code
import h5py
import xarray as xr
import dask.array as da
import datatree as dt
from tsdownsample import MinMaxLTTBDownsampler


def _help_downsample(data, time, n_out):
    indices = MinMaxLTTBDownsampler().downsample(time, data, n_out=n_out)
    return data[indices], indices


def apply_downsample(ts_ds, n_out):
    ts_ds = ts_ds.load()
    ts_ds_downsampled, indices = xr.apply_ufunc(
        _help_downsample,
        ts_ds["data"],
        ts_ds["time"],
        kwargs=dict(n_out=n_out),
        input_core_dims=[["time"], ["time"]],
        output_core_dims=[["time"], ["indices"]],
        exclude_dims=set(("time",)),
        vectorize=True,
        dask="parallelized",
        output_sizes={"indices": n_out},
    )
    ts_ds_downsampled["time"] = ts_ds["time"].isel(time=indices.values[0])
    return ts_ds_downsampled.rename("data")


def build_dataset(f, data_key, dims):
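    # dims maps each dimension name to the HDF5 path of its coordinate values;
    # a dict (not a set) keeps each coordinate paired with its dimension
    coords = {name: f[path][:] for name, path in dims.items()}
    data = f[data_key]
    ds = xr.DataArray(
        da.from_array(data, name="data", chunks=(data.shape[0], 1)),
        dims=list(dims),
        coords=coords,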
    ).to_dataset()
    return ds


# adapted from https://github.com/carbonplan/ndpyramid/blob/main/ndpyramid/core.py
def multiscales_template(
    *,
    datasets: list = None,
    type: str = "",
    method: str = "",
    version: str = "",
    args: list = None,
    kwargs: dict = None,
):
    if datasets is None:
        datasets = []
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    # https://forum.image.sc/t/multiscale-arrays-v0-1/37930
    return [
        {
            "datasets": datasets,
            "type": type,
            "metadata": {
                "method": method,
                "version": version,
                "args": args,
                "kwargs": kwargs,
            },
        }
    ]


def pyramid_downsample(ds: xr.Dataset, *, factors: list[int], **kwargs) -> dt.DataTree:
    """Create a multiscale pyramid by downsampling a dataset by given factors

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to downsample.
    factors : list[int]
        The downsampling factors.
    kwargs : dict
        Currently unused; only recorded in the multiscales metadata.
    """

    # multiscales spec
    save_kwargs = locals()
    del save_kwargs["ds"]

    attrs = {
        "multiscales": multiscales_template(
            datasets=[{"path": str(i)} for i in range(len(factors))],
            type="pick",
            method="pyramid_downsample",
            version="0.1",
            kwargs=save_kwargs,
        )
    }

    # set up pyramid
    plevels = {}

    # pyramid data
    for key, factor in enumerate(factors):
        factor += 1
        result = apply_downsample(ds, len(ds["data"]) // factor)
        plevels[str(key)] = result

    plevels["/"] = xr.Dataset(attrs=attrs)
    return dt.DataTree.from_dict(plevels)

f = h5py.File("allensdk_cache/session_715093703/probe_810755797_lfp.nwb", "r")
ts_ds = build_dataset(
    f,
    "acquisition/probe_810755797_lfp_data/data",
    {
        "time": "acquisition/probe_810755797_lfp_data/timestamps",
        "channel": "acquisition/probe_810755797_lfp_data/electrodes",
    },
)
ts_dt = pyramid_downsample(ts_ds, factors=[0, 1, 2, 4, 8])


ts_dt.to_zarr("timeseries.zarr", mode="w")

dt.open_datatree("timeseries.zarr", engine="zarr")

@ahuang11
Collaborator

ahuang11 commented Feb 2, 2024

For task 2, load based on zoom level:

import numpy as np
import holoviews as hv
from holoviews.operation.datashader import datashade
hv.extension("bokeh")

def rescale(x_range):
    nlevels = len(ts_dt)
    sub_ds = ts_dt[str(nlevels - 1)].isel(channel=0).ds
    if x_range:
        x_slice = slice(*x_range)
        subset_length = ts_dt["0"].sel(time=x_slice)["time"].size
        level = str(nlevels - np.argmin(np.abs(lengths - subset_length)) - 1)
        # channel coords are electrode IDs, so pick the first channel by position
        sub_ds = ts_dt[level].isel(channel=0).sel(time=x_slice).ds
        print(f"Using {level} for {subset_length} samples")
    return hv.Curve(sub_ds, ["time"], ["data"])


range_stream = hv.streams.RangeX()
lengths = np.array([ts_dt[f"{i}"]["time"].size for i in range(len(ts_dt))])
dmap = hv.DynamicMap(rescale, streams=[range_stream])
dmap

@ahuang11
Collaborator

ahuang11 commented Feb 6, 2024

Here's the latest code, although it's been pointed out to me that apply_ufunc might not be needed here.

So, I'm wondering whether each channel should be downsampled individually (looped) or together (stacked)?

import h5py
import xarray as xr
import dask.array as da
import datatree as dt
from tsdownsample import MinMaxLTTBDownsampler


def _help_downsample(data, time, n_out):
    indices = MinMaxLTTBDownsampler().downsample(time, data, n_out=n_out)
    return data[indices], indices


def apply_downsample(ts_ds, n_out):
    ts_ds_downsampled, indices = xr.apply_ufunc(
        _help_downsample,
        ts_ds["data"],
        ts_ds["time"],
        kwargs=dict(n_out=n_out),
        input_core_dims=[["time"], ["time"]],
        output_core_dims=[["time"], ["indices"]],
        exclude_dims=set(("time",)),
        vectorize=True,
        dask="parallelized",
        dask_gufunc_kwargs=dict(output_sizes={"time": n_out, "indices": n_out}),
    )
    print(indices)
    print(indices[0])
    ts_ds_downsampled["time"] = ts_ds["time"].isel(time=indices.values[0])
    return ts_ds_downsampled.rename("data")


def build_dataset(f, data_key, dims):
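    # dims maps each dimension name to the HDF5 path of its coordinate values;
    # a dict (not a set) keeps each coordinate paired with its dimension
    coords = {name: f[path][:] for name, path in dims.items()}
    data = f[data_key]
    ds = xr.DataArray(
        da.from_array(data, name="data", chunks=(data.shape[0], 1)),
        dims=list(dims),
        coords=coords,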
    ).to_dataset()
    return ds


# adapted from https://github.com/carbonplan/ndpyramid/blob/main/ndpyramid/core.py
def multiscales_template(
    *,
    datasets: list = None,
    type: str = "",
    method: str = "",
    version: str = "",
    args: list = None,
    kwargs: dict = None,
):
    if datasets is None:
        datasets = []
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    # https://forum.image.sc/t/multiscale-arrays-v0-1/37930
    return [
        {
            "datasets": datasets,
            "type": type,
            "metadata": {
                "method": method,
                "version": version,
                "args": args,
                "kwargs": kwargs,
            },
        }
    ]


def pyramid_downsample(ds: xr.Dataset, *, factors: list[int], **kwargs) -> dt.DataTree:
    """Create a multiscale pyramid by downsampling a dataset by given factors

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to downsample.
    factors : list[int]
        The downsampling factors.
    kwargs : dict
        Currently unused; only recorded in the multiscales metadata.
    """

    # multiscales spec
    save_kwargs = locals()
    del save_kwargs["ds"]

    attrs = {
        "multiscales": multiscales_template(
            datasets=[{"path": str(i)} for i in range(len(factors))],
            type="pick",
            method="pyramid_downsample",
            version="0.1",
            kwargs=save_kwargs,
        )
    }

    # set up pyramid
    plevels = {}

    # pyramid data
    for key, factor in enumerate(factors):
        factor += 1
        result = apply_downsample(ds, len(ds["data"]) // factor)
        plevels[str(key)] = result

    plevels["/"] = xr.Dataset(attrs=attrs)
    return dt.DataTree.from_dict(plevels)

f = h5py.File("allensdk_cache/session_715093703/probe_810755797_lfp.nwb", "r")
ts_ds = build_dataset(
    f,
    "acquisition/probe_810755797_lfp_data/data",
    {
        "time": "acquisition/probe_810755797_lfp_data/timestamps",
        "channel": "acquisition/probe_810755797_lfp_data/electrodes",
    },
).isel(channel=[0, 1, 2, 3, 4])
ts_dt = pyramid_downsample(ts_ds, factors=[0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512])

ts_dt.to_zarr("timeseries.zarr", mode="w")

ts_dt = dt.open_datatree("timeseries.zarr", engine="zarr")

import numpy as np
import holoviews as hv
from holoviews.operation.datashader import datashade
hv.extension("bokeh")

def rescale(x_range):
    nlevels = len(ts_dt)
    sub_ds = ts_dt[str(nlevels - 1)].isel(channel=0).ds
    if x_range:
        x_slice = slice(*x_range)
        subset_length = ts_dt["0"].sel(time=x_slice)["time"].size
        level = str(nlevels - np.argmin(np.abs(lengths - subset_length)) - 1)
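        # channel coords are electrode IDs, so pick the first channel by position
        sub_ds = ts_dt[level].isel(channel=0).sel(time=x_slice).ds
        # level is a string key; convert it before indexing the lengths array
        print(f"Using {level} {lengths[int(level)]} for {subset_length} samples")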
    return hv.Curve(sub_ds, ["time"], ["data"])


range_stream = hv.streams.RangeX()
lengths = np.array([ts_dt[f"{i}"]["time"].size for i in range(len(ts_dt))])
dmap = hv.DynamicMap(rescale, streams=[range_stream])
dmap

@philippjfr
Collaborator

So, I'm wondering whether each channel should be downsampled individually (looped) or together (stacked)?

Timeseries should be downsampled individually (although that can of course happen in parallel), since the LTTB algorithm attempts to preserve the structure of each trace separately. But as you point out, that means their time coordinates are no longer aligned, which makes them much more difficult and expensive to store. So I'm not 100% sure what the best approach is here. We may have to do something simple, like averaging across channels and then computing the downsampling indices on the averaged data.
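
The "average across channels, then compute indices once" idea can be sketched in a few lines. This sketch substitutes a simple min-max bucket downsampler for tsdownsample's MinMaxLTTBDownsampler, so it illustrates the alignment trick, not the exact algorithm used above. The helper names and data are made up. All channels are indexed with the same positions, so they keep one shared time axis.

```python
import numpy as np


def minmax_indices(y, n_out):
    """Min-max bucket downsampler: split y into n_out // 2 equal index ranges
    and keep the argmin and argmax of each. A simple stand-in for
    tsdownsample's MinMaxLTTBDownsampler, not the same algorithm."""
    edges = np.linspace(0, len(y), n_out // 2 + 1).astype(int)
    idx = []
    for start, stop in zip(edges[:-1], edges[1:]):
        seg = y[start:stop]
        if seg.size:
            idx.extend((start + seg.argmin(), start + seg.argmax()))
    return np.unique(idx)  # sorted, duplicates removed


def shared_downsample(data, time, n_out):
    """Downsample a (channel, time) array with ONE index set computed on the
    channel-averaged trace, so every channel keeps the same timestamps."""
    idx = minmax_indices(data.mean(axis=0), n_out)
    return data[:, idx], time[idx]


rng = np.random.default_rng(3)
time = np.arange(1000) / 25_000          # 25 kHz sample times, in seconds
data = rng.normal(size=(4, time.size))   # 4 made-up channels
small_data, small_time = shared_downsample(data, time, n_out=100)
```

The tradeoff is that extremes unique to one channel can be dropped when they don't show up in the average. That is the structure-preservation cost mentioned above.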

@droumis
Collaborator Author

droumis commented Feb 6, 2024

So, I'm wondering whether each channel should be downsampled individually (looped) or together (stacked)?

hmmm... that's tricky. In the short term, I think we just need to make a hard tradeoff here. LTTB is ideal for a single timeseries, but applying it to a full dataset while still maintaining alignment through some approach like averaging across channels feels like a whole research project. For this first pass, I suggest we prioritize speed and just apply decimation to keep the timeseries aligned. I know decimation is a bit of a dirty word around here, but I think it's a necessary first step.

In the future we can revisit, and perhaps if a user zooms in [enough] then it becomes more reasonable to apply LTTB per timeseries for the data in the viewport.

@philippjfr, thoughts?
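
For reference, the decimation proposed above can be as simple as striding along the time axis. The sketch below uses made-up data. It is plain stride decimation, not scipy.signal.decimate, which also applies an anti-aliasing filter.

```python
import numpy as np


def decimate(data, time, factor):
    """Stride decimation of a (channel, time) array: keep every factor-th
    sample. Unlike scipy.signal.decimate, no anti-aliasing filter is applied,
    so peaks between kept samples can be lost, but all channels share one
    time axis."""
    return data[:, ::factor], time[::factor]


data = np.arange(30).reshape(3, 10)   # 3 channels x 10 samples
time = np.arange(10)
dec_data, dec_time = decimate(data, time, factor=4)
```

Every channel keeps samples 0, 4, and 8, so the result stays a dense (channel, time) array with a single 1D time coordinate.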

@philippjfr
Collaborator

Not sure whether decimation or downsampling is the better approach here. We could certainly apply regrid to the 2D array, keeping the number of channels fixed so that it only resamples along the time dimension.

@droumis
Copy link
Collaborator Author

droumis commented Feb 13, 2024

regrid is a HoloViews operation that uses Datashader's raster resampling.

Ensure the number of pixels in height matches the number of channels, so the channel dimension is preserved and only the time dimension is resampled.
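
The shape of that operation can be sketched without Datashader: average along time into a fixed number of bins while leaving the channel axis untouched, so the output height equals the channel count. The sketch below is a rough numpy stand-in under stated assumptions. It is not how regrid is implemented. It assumes evenly spaced samples and simply drops any ragged tail that doesn't fill a bin.

```python
import numpy as np


def bin_time(data, time, width_px):
    """Average a (channel, time) array into width_px time bins while keeping
    every channel row, so output height == number of channels. A rough numpy
    stand-in for HoloViews' regrid, not its implementation."""
    n_chan, n_time = data.shape
    if n_time < width_px:
        raise ValueError("fewer samples than target bins")
    per_bin = n_time // width_px
    usable = per_bin * width_px  # drop the ragged tail that doesn't fill a bin
    binned = data[:, :usable].reshape(n_chan, width_px, per_bin).mean(axis=2)
    centers = time[:usable].reshape(width_px, per_bin).mean(axis=1)
    return binned, centers


data = np.arange(20, dtype=float).reshape(2, 10)  # 2 channels x 10 samples
time = np.arange(10, dtype=float)
binned, centers = bin_time(data, time, width_px=5)
```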

@ahuang11
Collaborator

ahuang11 commented Feb 16, 2024

Thanks to advice from Deepak Cherian and Sam Levang (pydata/xarray#8695 (comment)), I made a mini breakthrough: keeping all the downsampled times while staying blazingly fast (if I'm not mistaken... :P). 10 channels can be downsampled and exported to zarr in 1 minute 20 seconds, using 1.48 GB of disk space (potentially less if I use different dtypes and reduce the number of zoom levels).

It's a combination of apply_ufunc + recreating the data array with the unique coordinates as a multi-dimensional coordinate.
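
The storage trick is easier to see in isolation. In the sketch below, the timestamps, the per-channel indices, and the values are all made up. It mirrors the DataArray that apply_downsample returns: "time" becomes a positional dimension of length n_out, and the real timestamps live in a 2D (channel, time) coordinate. Channels can therefore keep different samples without padding or a sparse array.

```python
import numpy as np
import xarray as xr

time = np.linspace(0.0, 1.0, 11)   # made-up full-resolution timestamps
# Made-up indices each channel kept after per-channel downsampling.
indices = np.array([[0, 3, 7, 10],
                    [0, 5, 6, 10]])
values = np.random.default_rng(2).normal(size=indices.shape)

# "time" is now just a positional dim of length n_out; the real timestamps
# live in a 2D (channel, time) coordinate, so channels can disagree.
da_multi = xr.DataArray(
    values,
    dims=["channel", "time"],
    coords={"multi_time": (("channel", "time"), time[indices])},
    name="data",
)

# Timestamps for one channel, e.g. to build its curve.
ch1_times = da_multi.isel(channel=1)["multi_time"].values
```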

Code
import h5py
import xarray as xr
import dask.array as da
import datatree as dt
from tsdownsample import MinMaxLTTBDownsampler


def _help_downsample(data, time, n_out):
    indices = MinMaxLTTBDownsampler().downsample(time, data, n_out=n_out)
    return data[indices], indices


def apply_downsample(ts_ds, n_out):
    print(n_out)
    times = ts_ds["time"]
    ts_ds_downsampled, time_indices = xr.apply_ufunc(
        _help_downsample,
        ts_ds,
        times,
        kwargs=dict(n_out=n_out),
        input_core_dims=[["time"], ["time"]],
        output_core_dims=[["time"], ["indices"]],
        exclude_dims=set(("time",)),
        vectorize=False,
        dask="parallelized",
        dask_gufunc_kwargs=dict(output_sizes={"time": n_out, "indices": n_out}),
    )
    ts_da_downsampled = ts_ds_downsampled.to_dataarray("channel")
    times_subset = times.values[time_indices.to_array().values]
    return xr.DataArray(
        ts_da_downsampled.data,
        dims=["channel", "time"],
        coords={"multi_time": (("channel", "time"), times_subset)},
        name="data",
    )


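def build_dataset(f, data_key, dims):
    # dims maps each dimension name to the HDF5 path holding its coordinate values
    coords = {dim: f[path] for dim, path in dims.items()}  # dict, not a set
    data = f[data_key]
    ds = xr.DataArray(
        da.from_array(data, name="data", chunks=(data.shape[0], 1)),
        dims=list(dims),
        coords=coords,
        name="data",  # a name is required for .to_dataset()
    ).to_dataset()
    return ds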


# adapted from https://github.com/carbonplan/ndpyramid/blob/main/ndpyramid/core.py
def multiscales_template(
    *,
    datasets: list = None,
    type: str = "",
    method: str = "",
    version: str = "",
    args: list = None,
    kwargs: dict = None,
):
    if datasets is None:
        datasets = []
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    # https://forum.image.sc/t/multiscale-arrays-v0-1/37930
    return [
        {
            "datasets": datasets,
            "type": type,
            "metadata": {
                "method": method,
                "version": version,
                "args": args,
                "kwargs": kwargs,
            },
        }
    ]


def pyramid_downsample(ds: xr.Dataset, *, factors: list[int], **kwargs) -> dt.DataTree:
    """Create a multiscale pyramid via coarsening of a dataset by given factors

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to coarsen.
    factors : list[int]
        The factors to coarsen by.
    kwargs : dict
        Additional keyword arguments to pass to xarray.Dataset.coarsen.
    """
    ds = ds["data"].to_dataset("channel")

    # multiscales spec
    save_kwargs = locals()
    del save_kwargs["ds"]

    attrs = {
        "multiscales": multiscales_template(
            datasets=[{"path": str(i)} for i in range(len(factors))],
            type="pick",
            method="pyramid_downsample",
            version="0.1",
            kwargs=save_kwargs,
        )
    }

    # set up pyramid
    plevels = {}

    # pyramid data
    for key, factor in enumerate(factors):
        factor += 1
        result = apply_downsample(ds, len(ds["time"]) // factor)
        plevels[str(key)] = result

    plevels["/"] = xr.Dataset(attrs=attrs)
    return dt.DataTree.from_dict(plevels)


f = h5py.File("allensdk_cache/session_715093703/probe_810755797_lfp.nwb", "r")
ts_ds = build_dataset(
    f,
    "acquisition/probe_810755797_lfp_data/data",
    {
        "time": "acquisition/probe_810755797_lfp_data/timestamps",
        "channel": "acquisition/probe_810755797_lfp_data/electrodes",
    },
).isel(channel=slice(0, 10))
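The key idea in the code above is that each channel keeps its own selected sample indices, so the downsampled times differ per channel and have to be stored as a 2D `("channel", "time")` coordinate (`multi_time`). The sketch below shows that bookkeeping in plain NumPy. It swaps in a simple per-bin min/max selection for tsdownsample's `MinMaxLTTBDownsampler` (an assumption for illustration, not the same algorithm), and `minmax_indices` and `downsample_channels` are hypothetical helpers.

```python
import numpy as np


def minmax_indices(y: np.ndarray, n_out: int) -> np.ndarray:
    """Sorted indices of the min and max sample in each of n_out // 2 equal bins."""
    if n_out < 2 or n_out % 2 or n_out > len(y):
        raise ValueError("n_out must be even, >= 2, and <= len(y)")
    edges = np.linspace(0, len(y), n_out // 2 + 1).astype(int)
    idx = []
    for start, stop in zip(edges[:-1], edges[1:]):
        seg = y[start:stop]
        idx.extend([start + np.argmin(seg), start + np.argmax(seg)])
    return np.sort(np.array(idx))


def downsample_channels(data: np.ndarray, times: np.ndarray, n_out: int):
    """Downsample each row of a (n_channels, n_samples) array independently.

    Both outputs have shape (n_channels, n_out): every channel may keep
    different sample indices, so the times become a 2D array, analogous to
    the ("channel", "time") multi_time coordinate.
    """
    indices = np.vstack([minmax_indices(row, n_out) for row in data])
    values = np.take_along_axis(data, indices, axis=1)
    return values, times[indices]


times = np.linspace(0, 1, 100)
data = np.vstack([np.sin(2 * np.pi * 3 * times), np.cos(2 * np.pi * 5 * times)])
values, multi_time = downsample_channels(data, times, n_out=10)
print(values.shape, multi_time.shape)  # (2, 10) (2, 10)
```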

Regarding the plot, I implemented part of it:

Code
import numpy as np
import holoviews as hv
from holoviews.operation.datashader import datashade

hv.extension("bokeh")


def _extract_ds(ts_dt, level, channel):
    ds = (
        ts_dt[str(level)]
        .isel(channel=channel)
        .ds.swap_dims({"time": "multi_time"})
        .rename({"multi_time": "time"})
    )
    return ds


def rescale(x_range):
    if x_range is None:
        x_range = time_da.min().item(), time_da.max().item()

    x_slice = slice(*[float(x) for x in x_range])
    sub_length = time_da.sel(time=x_slice).size
    min_index = np.argmin(np.abs(lengths - sub_length))
    zoom_level = max_index - min_index - 1
    curves = hv.Overlay(kdims="Channel")
    for channel in channels:
        sub_ds = _extract_ds(ts_dt, zoom_level, channel).sel(time=x_slice)
        curves *= hv.Curve(sub_ds, ["time"], ["data"], label=f"channel {channel}").opts(
            color="black",
            line_width=1,
            subcoordinate_y=True,
            subcoordinate_scale=3,
        )

    return curves.opts(
        xlabel="Time (s)",
        ylabel="Channel",
        show_legend=False,
        padding=0,
        aspect=1.5,
        responsive=True,
        shared_axes=False,
        framewise=False,
    )


time_da = _extract_ds(ts_dt, 0, 0)["time"]
channels = ts_dt["0"].ds["channel"].values
lengths = np.array(
    [ts_dt[f"{i}"].isel(channel=0).ds["time"].size for i in range(len(ts_dt))]
)
max_index = len(lengths)
range_stream = hv.streams.RangeX()
dmap = hv.DynamicMap(rescale, streams=[range_stream])
dmap

The issue is that the range stream doesn't trigger, maybe because the callback returns an overlay?


I tried to work around this by overlaying one DynamicMap per channel:

Code
from functools import partial

import numpy as np
import holoviews as hv

hv.extension("bokeh")


def _extract_ds(ts_dt, level, ch):
    ds = (
        ts_dt[str(level)]
        .isel(channel=ch)
        .ds.swap_dims({"time": "multi_time"})
        .rename({"multi_time": "time"})
    )
    return ds


def rescale(ch, x_range):
    if x_range is None:
        x_range = time_da[0].item(), time_da[-1].item()

    x_slice = slice(*[float(x) for x in x_range])
    sub_length = time_da.sel(time=x_slice).size
    min_index = np.argmin(np.abs(lengths - sub_length))
    zoom_level = max_index - min_index - 1
    sub_ds = _extract_ds(ts_dt, zoom_level, ch).sel(time=x_slice)
    curve = hv.Curve(sub_ds, ["time"], ["data"], label=f"ch{ch}").opts(
        color="black",
        line_width=1,
        subcoordinate_y=True,
        subcoordinate_scale=3,
    )
    print(zoom_level)
    return curve


time_da = _extract_ds(ts_dt, 0, 0)["time"]
channels = ts_dt["0"].ds["channel"].values
lengths = np.array(
    [ts_dt[f"{i}"].isel(channel=0).ds["time"].size for i in range(len(ts_dt))]
)
max_index = len(lengths)

curves = hv.Overlay(kdims=["Channel"])
range_stream = hv.streams.RangeX()
for channel in channels:
    curves *= hv.DynamicMap(
        partial(rescale, ch=channel), streams=[range_stream]
    )

curves.opts(
    xlabel="Time (s)",
    ylabel="Channel",
    show_legend=False,
)

However, it still doesn't work.

Screen.Recording.2024-02-15.at.4.08.47.PM.mov

@ahuang11
Collaborator

Trying it on all channels:

It took ~10 minutes and produced 10 GB of output.

@ahuang11
Collaborator

The bug is that the subcoordinate options break overlaid DynamicMaps:

        subcoordinate_y=True,
        subcoordinate_scale=3,

@droumis
Collaborator Author

droumis commented Feb 19, 2024

oooo, multi-dimensional coordinates, nice! Great work so far @ahuang11

Could you spend a bit of time filing MRE HoloViews issues for the subcoordinates breaking overlaid DMap and for the range stream not triggering? These both seem pretty important to address.

I'm not sure about the performance implications of creating a DMap for every channel vs using a single DMap with an overlay of curves. @philippjfr, any intuition about this?

@philippjfr
Collaborator

For prioritizing Andrew's next steps, what do you think about the following, which delays the data interface for HoloViews until after both use-cases are explored:

Those steps sound good. I'd also like to see some profiling to see where we spend most of the time currently.

@ahuang11
Collaborator

Is there a native tool for profiling DynamicMaps, or do I just wrap the callback in time.perf_counter()?

@philippjfr
Collaborator

I'd use %%prun in a notebook, and time both the initial render and sending an update on the stream.
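Outside a notebook, roughly the same information as `%%prun` can be collected with the standard-library `cProfile` and `pstats` modules. In this sketch `callback` is a hypothetical stand-in for a DynamicMap callback such as `rescale`, and `profile_call` is a hypothetical helper; timing the initial render and a stream update separately would mean profiling those two calls separately.

```python
import cProfile
import io
import pstats


def callback(n):
    # Hypothetical stand-in for an expensive DynamicMap callback such as rescale().
    return sum(i * i for i in range(n))


def profile_call(func, *args, sort_by="cumulative", limit=10, **kwargs):
    """Run func(*args, **kwargs) under cProfile; return (result, report text)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    pstats.Stats(profiler, stream=stream).sort_stats(sort_by).print_stats(limit)
    return result, stream.getvalue()


result, report = profile_call(callback, 100_000)
print(report)
```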

@ahuang11
Collaborator

ahuang11 commented Feb 27, 2024

For the deep image stack use-case, I didn't need to use a pyramid to get it running efficiently.

I think I only needed to persist the right/bottom datasets, because each VLine/HLine update seemed to trigger unnecessary re-computation even though that data was static. I also rewrote the hvPlot code in HoloViews, but I'm not sure whether that was necessary.

Screen.Recording.2024-02-26.at.6.03.21.PM.mov
import xarray as xr
import panel as pn

pn.extension(throttled=True)
import holoviews as hv
from holoviews.operation.datashader import rasterize

hv.extension("bokeh")

DATA_ARRAY = "10000frames"

DATA_PATH = f"miniscope/miniscope_sim_{DATA_ARRAY}.zarr"

ldataset = xr.open_zarr(DATA_PATH, chunks="auto")

data = ldataset[DATA_ARRAY]

FRAMES_PER_SECOND = 30
FRAMES = data.coords["frame"].values


def plot_image(value):
    return hv.Image(data.sel(frame=value), kdims=["width", "height"]).opts(
        cmap="Viridis",
        frame_height=400,
        frame_width=400,
        colorbar=False,
    )


# Create a video player widget
video_player = pn.widgets.Player(
    length=len(data.coords["frame"]),
    interval=1000 // FRAMES_PER_SECOND,  # ms
    value=int(FRAMES.min()),
    max_width=400,
    max_height=90,
    loop_policy="loop",
    sizing_mode="stretch_width",
)

# Create the main plot
main_plot = hv.DynamicMap(
    plot_image, kdims=["value"], streams=[video_player.param.value]
)

# frame indicator lines on side plots
line_opts = dict(color="red", alpha=0.6, line_width=3)
dmap_hline = hv.DynamicMap(pn.bind(lambda value: hv.HLine(value), video_player)).opts(
    **line_opts
)
dmap_vline = hv.DynamicMap(pn.bind(lambda value: hv.VLine(value), video_player)).opts(
    **line_opts
)

# height side view
right_data = data.mean(["width"]).persist()
right_plot = rasterize(
    hv.Image(right_data, kdims=["frame", "height"]).opts(
        cmap="Viridis",
        frame_height=400,
        frame_width=200,
        colorbar=False,
        title="_",
    )
)


# width side view
bottom_data = data.mean(["height"]).persist()
bottom_plot = rasterize(
    hv.Image(bottom_data, kdims=["width", "frame"]).opts(
        cmap="Viridis",
        frame_height=200,
        frame_width=400,
        colorbar=False,
    )
)

video_player.margin = (20, 20, 20, 70)  # center widget over main

sim_app = pn.Column(
    video_player,
    pn.Row(main_plot, right_plot * dmap_vline),
    bottom_plot * dmap_hline,
)

sim_app

Since the data is only about 400 MB after taking the average, I can load it into memory for even better performance.


It takes about 10 seconds to initialize the plots on an M2 Pro.

Screen.Recording.2024-02-26.at.6.10.33.PM.mov

@ahuang11
Collaborator

ahuang11 commented Feb 27, 2024

I tested the miniscope with rasterize's streams set only to RangeXY and without load/persist. The mean computation is still triggered whenever the Player value changes @philippjfr

Screen.Recording.2024-02-27.at.12.12.15.PM.mov
import xarray as xr
import panel as pn

pn.extension(throttled=True)
import holoviews as hv
from holoviews.operation.datashader import rasterize

hv.extension("bokeh")

DATA_ARRAY = "10000frames"

DATA_PATH = f"miniscope/miniscope_sim_{DATA_ARRAY}.zarr"

ldataset = xr.open_zarr(DATA_PATH, chunks="auto")

data = ldataset[DATA_ARRAY]

FRAMES_PER_SECOND = 30
FRAMES = data.coords["frame"].values


def plot_image(value):
    return hv.Image(data.sel(frame=value), kdims=["width", "height"]).opts(
        cmap="Viridis",
        frame_height=400,
        frame_width=400,
        colorbar=False,
    )


# Create a video player widget
video_player = pn.widgets.Player(
    length=len(data.coords["frame"]),
    interval=1000 // FRAMES_PER_SECOND,  # ms
    value=int(FRAMES.min()),
    max_width=400,
    max_height=90,
    loop_policy="loop",
    sizing_mode="stretch_width",
)

# Create the main plot
main_plot = hv.DynamicMap(
    plot_image, kdims=["value"], streams=[video_player.param.value]
)

# frame indicator lines on side plots
line_opts = dict(color="red", alpha=0.6, line_width=3)
dmap_hline = hv.DynamicMap(pn.bind(lambda value: hv.HLine(value), video_player)).opts(
    **line_opts
)
dmap_vline = hv.DynamicMap(pn.bind(lambda value: hv.VLine(value), video_player)).opts(
    **line_opts
)

from holoviews.streams import RangeXY

# height side view
right_data = data.mean(["width"])
right_plot = rasterize(
    hv.Image(right_data, kdims=["frame", "height"]).opts(
        cmap="Viridis",
        frame_height=400,
        frame_width=200,
        colorbar=False,
        title="_",
    ),
    streams=[RangeXY()],
)


# width side view
bottom_data = data.mean(["height"])
bottom_plot = rasterize(
    hv.Image(bottom_data, kdims=["width", "frame"]).opts(
        cmap="Viridis",
        frame_height=200,
        frame_width=400,
        colorbar=False,
    ),
    streams=[RangeXY()],
)

video_player.margin = (20, 20, 20, 70)  # center widget over main

sim_app = pn.Column(
    video_player,
    pn.Row(main_plot, right_plot * dmap_vline),
    bottom_plot * dmap_hline,
)

sim_app

@philippjfr
Collaborator

I tested the miniscope with rasterize's streams only set to RangeXY and without load/persist. The mean computation still trigger on change of Player value @philippjfr

Please try to make a minimum reproducible example.

@ahuang11
Collaborator

See holoviz/holoviews#6135

@ahuang11
Collaborator

ahuang11 commented Feb 28, 2024

I made the zoom level depend on the plot size instead: for each zoom level, it computes the size of the sliced time range, subtracts the plot width, and picks the level whose size is closest to the width.

        sizes = [
            _extract_ds(ts_dt, zoom_level, 0)["time"].sel(time=time_slice).size
            for zoom_level in range(num_levels)
        ]
        zoom_level = np.argmin(np.abs(np.array(sizes) - width))
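The same selection can be sketched without xarray: given each level's sorted time array, count the samples inside the visible range with `np.searchsorted` and pick the level whose count is closest to the plot width in pixels. `pick_zoom_level` is a hypothetical helper that assumes each level's times are sorted; like the snippet above, it would use one channel's times as a proxy for all channels.

```python
import numpy as np


def pick_zoom_level(level_times, x_range, width):
    """Return (level index, per-level counts) for the level whose number of
    samples inside x_range is closest to the plot width."""
    x0, x1 = x_range
    counts = np.array(
        [
            np.searchsorted(t, x1, side="right") - np.searchsorted(t, x0, side="left")
            for t in level_times
        ]
    )
    return int(np.argmin(np.abs(counts - width))), counts


# Three hypothetical levels covering 0-100 s with 10000, 1000, and 100 samples
level_times = [np.linspace(0, 100, n) for n in (10_000, 1_000, 100)]
level, counts = pick_zoom_level(level_times, (0, 100), width=800)
print(level)  # 1
```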

However, I think it still needs a zoom multiplier to account for the number of channels (it takes a while to load).

Screen.Recording.2024-02-27.at.4.20.56.PM.mov
import numpy as np
import panel as pn
import datatree as dt
import holoviews as hv
from scipy.stats import zscore
from holoviews.plotting.links import RangeToolLink
from holoviews.operation.datashader import rasterize
from bokeh.models.tools import WheelZoomTool, HoverTool

hv.extension("bokeh")


CLIM_MUL = 0.5
MAX_CHANNELS = 40
X_PADDING = 0.2  # padding for the x_range


def _extract_ds(ts_dt, level, channel):
    ds = (
        ts_dt[str(level)]
        .sel(channel=channel)
        .ds.swap_dims({"time": "multi_time"})
        .rename({"multi_time": "time"})
    )
    return ds


def rescale(x_range, y_range, width, scale, height):
    if x_range is None:
        x_range = time_da.min().item(), time_da.max().item()
    if y_range is None:
        y_range = 0, num_channels
    x_padding = (x_range[1] - x_range[0]) * X_PADDING
    time_slice = slice(x_range[0] - x_padding, x_range[1] + x_padding)

    if width is None or height is None:
        zoom_level = num_levels - 1
    else:
        sizes = [
            _extract_ds(ts_dt, zoom_level, 0)["time"].sel(time=time_slice).size
            for zoom_level in range(num_levels)
        ]
        zoom_level = np.argmin(np.abs(np.array(sizes) - width))

    curves = hv.Overlay(kdims="Channel")
    for channel in channels:
        hover = HoverTool(
            tooltips=[
                ("Channel", str(channel)),
                ("Time", "$x s"),
                ("Amplitude", "$y µV"),
            ]
        )
        sub_ds = _extract_ds(ts_dt, zoom_level, channel).sel(time=time_slice).load()
        curve = hv.Curve(sub_ds, ["time"], ["data"], label=f"ch{channel}").opts(
            color="black",
            line_width=1,
            subcoordinate_y=True,
            subcoordinate_scale=3,
            default_tools=["pan", "reset", WheelZoomTool(), hover],
        )
        curves *= curve
    return curves.opts(
        xlabel="Time (s)",
        ylabel="Channel",
        title=f"level {zoom_level} ({x_range[0]:.2f}s - {x_range[1]:.2f}s)",
        show_legend=False,
        padding=0,
        aspect=1.5,
        responsive=True,
        framewise=True,
        axiswise=True,
    )


ts_dt = dt.open_datatree("pyramid.zarr", engine="zarr").sel(
    channel=slice(0, MAX_CHANNELS)
)
num_levels = len(ts_dt)
time_da = _extract_ds(ts_dt, 0, 0)["time"]
channels = ts_dt["0"].ds["channel"].values
num_channels = len(channels)
data = ts_dt["0"].ds["data"].values.T
range_stream = hv.streams.RangeXY()
size_stream = hv.streams.PlotSize()
dmap = hv.DynamicMap(rescale, streams=[size_stream, range_stream])

y_positions = range(num_channels)
yticks = [(i, ich) for i, ich in enumerate(channels)]
z_data = zscore(data.T, axis=1)

minimap = rasterize(
    hv.Image((time_da, y_positions, z_data), ["Time (s)", "Channel"], "Amplitude (uV)")
)
minimap = minimap.opts(
    cmap="RdBu_r",
    xlabel="",
    yticks=[yticks[0], yticks[-1]],
    toolbar="disable",
    height=120,
    responsive=True,
    clim=(-z_data.std() * CLIM_MUL, z_data.std() * CLIM_MUL),
)
tool_link = RangeToolLink(
    minimap,
    dmap,
    axes=["x", "y"],
    boundsx=(0, time_da.max().item() // 2),
    boundsy=(0, len(channels)),
)
pn.template.FastListTemplate(main=[(dmap + minimap).cols(1)]).show()

I also tried Simon's suggestion of instantiating a zoom tool manually through Bokeh, but I still can't get the y-range to work properly with subcoordinates, and the box is unlinked. MRE:

holoviz/holoviews#6136

@droumis droumis changed the title POC Xarray-Datatree-Based Multi-Resolution Generator and Dynamic Accessor POC Xarray-Datatree-Based Multi-(Time)-Resolution Generator and Dynamic Accessor Mar 6, 2024
@droumis droumis changed the title POC Xarray-Datatree-Based Multi-(Time)-Resolution Generator and Dynamic Accessor Multi-(Time)-Scale Generator and Dynamic Accessor Mar 6, 2024
@ahuang11
Collaborator

ahuang11 commented Mar 8, 2024

comment about potential options for determining downscale factors

Regarding that, since zoom levels scale by powers of two along each axis (https://wiki.openstreetmap.org/wiki/Zoom_levels; e.g. z=0 has 1 tile, z=1 has 4 tiles, z=2 has 16 tiles), I came up with a formula to determine the downscale factors needed to cover the entire range. It depends on the length of the data, the typical screen width users will view it on, and the desired number of zoom levels (num_factors).

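import numpy as np

data_length = ts_ds["time"].size
screen_width = 1500  # typical screen width in px
num_factors = 4  # desired number of zoom levels

# factor needed to reduce the full series to roughly one sample per pixel
target_factor = data_length / screen_width
max_zoom = int(np.log2(target_factor))
all_factors = 2 ** np.arange(max_zoom + 1)
# evenly spaced powers of two; a step of at least 1 avoids a zero slice step
sub_factors = all_factors[:: max(1, max_zoom // (num_factors - 1))]
# always include the coarsest factor
if sub_factors[-1] != all_factors[-1]:
    sub_factors = np.append(sub_factors, all_factors[-1])
ts_dt = pyramid_downsample(ts_ds, factors=sub_factors)
ts_dt.to_zarr("pyramid.zarr", mode="w")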
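As a worked example of the formula, assume a hypothetical series of 3,000,000 samples viewed at 1500 px: the target factor is 2000, `int(log2(2000))` is 10, so the candidate factors are 2^0 through 2^10. With 4 requested levels the step is 10 // 3 = 3, which picks 1, 8, 64 and 512, and the coarsest factor 1024 is then appended, so this case yields 5 levels rather than 4. `downscale_factors` is a hypothetical wrapper around the same logic.

```python
import numpy as np


def downscale_factors(data_length, screen_width=1500, num_factors=4):
    """Power-of-two downscale factors from full resolution to ~1 sample per pixel."""
    target_factor = data_length / screen_width
    max_zoom = int(np.log2(target_factor))
    all_factors = 2 ** np.arange(max_zoom + 1)
    step = max(1, max_zoom // max(1, num_factors - 1))
    sub_factors = all_factors[::step]
    # Always include the coarsest factor.
    if sub_factors[-1] != all_factors[-1]:
        sub_factors = np.append(sub_factors, all_factors[-1])
    return sub_factors.tolist()


print(downscale_factors(3_000_000))  # [1, 8, 64, 512, 1024]
```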

@ahuang11
Collaborator

ahuang11 commented Mar 9, 2024

Code
import h5py
import numpy as np
import xarray as xr
import dask.array as da
import datatree as dt
from tsdownsample import MinMaxLTTBDownsampler


def _help_downsample(data, time, n_out):
    print(f"{len(time)} -> {n_out} samples")
    indices = MinMaxLTTBDownsampler().downsample(time, data, n_out=n_out)
    return data[indices], indices


def apply_downsample(ts_da, n_out):
    times = ts_da["time"]
    if n_out != len(ts_da["time"]):
        print("Downsampling")
        ts_ds_downsampled, time_indices = xr.apply_ufunc(
            _help_downsample,
            ts_da,
            times,
            kwargs=dict(n_out=n_out),
            input_core_dims=[["time"], ["time"]],
            output_core_dims=[["time"], ["indices"]],
            exclude_dims=set(("time",)),
            vectorize=False,
            dask="parallelized",
            dask_gufunc_kwargs=dict(output_sizes={"time": n_out, "indices": n_out}),
        )
        ts_da_downsampled = ts_ds_downsampled.to_dataarray("channel")
        times_subset = times.values[time_indices.to_array().values]
    else:
        print("No downsampling")
        ts_da_downsampled = ts_da.to_dataarray("channel")
        times_subset = np.vstack(
            [times.values for _ in range(ts_da_downsampled["channel"].size)]
        )
    return xr.DataArray(
        ts_da_downsampled.values,
        dims=["channel", "time"],
        coords={"multi_time": (("channel", "time"), times_subset)},
        name="data",
    )


def build_dataset(f, data_key, dims, use_dask=True):
    coords = {key: f[value] for key, value in dims.items()}
    data = f[data_key]
    if len(dims) != data.ndim:
        dims["channel"] = None
    if use_dask:
        array = da.from_array(data, name="data", chunks=(data.shape[0], 5))
    else:
        array = data
    ds = xr.DataArray(
        array,
        dims=dims,
        coords=coords,
        name="data"
    ).to_dataset()
    return ds


# adapted from https://github.com/carbonplan/ndpyramid/blob/main/ndpyramid/core.py
def multiscales_template(
    *,
    datasets: list = None,
    type: str = "",
    method: str = "",
    version: str = "",
    args: list = None,
    kwargs: dict = None,
):
    if datasets is None:
        datasets = []
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    # https://forum.image.sc/t/multiscale-arrays-v0-1/37930
    return [
        {
            "datasets": datasets,
            "type": type,
            "metadata": {
                "method": method,
                "version": version,
                "args": args,
                "kwargs": kwargs,
            },
        }
    ]


def pyramid_downsample(ds: xr.Dataset, *, factors: list[int], **kwargs) -> dt.DataTree:
    """Create a multiscale pyramid via coarsening of a dataset by given factors

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to coarsen.
    factors : list[int]
        The factors to coarsen by.
    kwargs : dict
        Additional keyword arguments to pass to xarray.Dataset.coarsen.
    """
    ds = ds["data"].to_dataset("channel")
    factors = np.array(factors).astype(int).tolist()

    # multiscales spec
    save_kwargs = locals()
    del save_kwargs["ds"]

    attrs = {
        "multiscales": multiscales_template(
            datasets=[{"path": str(i)} for i in range(len(factors))],
            type="pick",
            method="pyramid_downsample",
            version="0.1",
            kwargs=save_kwargs,
        )
    }

    # set up pyramid
    plevels = {}

    # pyramid data
    for key, factor in enumerate(factors):
        factor += 1
        result = apply_downsample(ds, len(ds["time"]) // factor)
        plevels[str(key)] = result

    plevels["/"] = xr.Dataset(attrs=attrs)
    return dt.DataTree.from_dict(plevels)

I fixed a couple of bugs with loading and did some timings too.

For preprocessing, the bottleneck is downsampling:

probe_810755797_lfp.nwb (loaded)
(~0.5min to reorganize data into xarray, 1.5min to run, 0.16 min to export)
4 factors: 2.25 mins
8 factors: 5.5 mins


probe_810755797_lfp.nwb (persisted)
4 factors: 1.5 mins

ephys_sim_neuropixels_10s_384ch.h5 (persisted)
4 factors: 15 seconds!

ephys_sim_neuropixels_200s_384ch.h5 (persisted)
4 factors: 4.75 mins

@ahuang11
Collaborator

ahuang11 commented Mar 9, 2024

Code
import numpy as np
import panel as pn
import datatree as dt
import holoviews as hv
from scipy.stats import zscore
from holoviews.plotting.links import RangeToolLink
from holoviews.operation.datashader import rasterize
from bokeh.models.tools import WheelZoomTool, HoverTool

hv.extension("bokeh")


CLIM_MUL = 0.5
MAX_CHANNELS = 100
X_PADDING = 0.2  # padding for the x_range


def _extract_ds(ts_dt, level, channel):
    ds = (
        ts_dt[str(level)]
        .sel(channel=channel)
        .ds.swap_dims({"time": "multi_time"})
        .rename({"multi_time": "time"})
    )
    return ds


def rescale(x_range, y_range, width, scale, height):
    import time

    s = time.perf_counter()
    print(f"- Update triggered! {width=} {x_range=}")
    if x_range is None:
        x_range = time_da.min().item(), time_da.max().item()
    if y_range is None:
        y_range = 0, num_channels
    x_padding = (x_range[1] - x_range[0]) * X_PADDING
    time_slice = slice(x_range[0] - x_padding, x_range[1] + x_padding)

    if width is None or height is None:
        zoom_level = num_levels - 1
        size = data.size
    else:
        sizes = [
            _extract_ds(ts_dt, zoom_level, 0)["time"].sel(time=time_slice).size
            for zoom_level in range(num_levels)
        ]
        zoom_level = np.argmin(np.abs(np.array(sizes) - width))
        size = sizes[zoom_level]
    e = time.perf_counter()
    print(f"Zoom level computation took {e-s:.2f}s")

    title = (
        f"level {zoom_level} ({x_range[0]:.2f}s - {x_range[1]:.2f}s) "
        f"(WxH: {width}x{height}) (length: {size})"
    )
    # if zoom_level == pn.state.cache.get("current_zoom_level") and pn.state.cache.get(
    #     "curves"
    # ):
    #     cached_x_range = pn.state.cache["x_range"]
    #     if x_range[0] >= cached_x_range[0] and x_range[1] <= cached_x_range[1]:
    #         print(f"Using cached curves! {zoom_level=}")
    #         if x_range != cached_x_range:
    #             print(f"Different x_range: {x_range} {cached_x_range}")
    #         return pn.state.cache["curves"].opts(title=title)

    curves = hv.Overlay(kdims="Channel")
    for channel in channels:
        hover = HoverTool(
            tooltips=[
                ("Channel", str(channel)),
                ("Time", "$x s"),
                ("Amplitude", "$y µV"),
            ]
        )
        sub_ds = _extract_ds(ts_dt, zoom_level, channel).sel(time=time_slice).load()
        curve = hv.Curve(sub_ds, ["time"], ["data"], label=f"ch{channel}").opts(
            color="black",
            line_width=1,
            subcoordinate_y=True,
            subcoordinate_scale=3,
            default_tools=["pan", "reset", WheelZoomTool(), hover],
        )
        curves *= curve
    print(f"Overlaying curves took {time.perf_counter()-e:.2f}s")

    curves = curves.opts(
        xlabel="Time (s)",
        ylabel="Channel",
        title=title,
        show_legend=False,
        padding=0,
        aspect=1.5,
        responsive=True,
        framewise=True,
        axiswise=True,
    )
    pn.state.cache["current_zoom_level"] = zoom_level
    pn.state.cache["x_range"] = x_range
    pn.state.cache["curves"] = curves
    print(f"Using updated curves! {x_range} {zoom_level}\n\n")
    return curves


ts_dt = dt.open_datatree("pyramid_4_evenly_factored.zarr", engine="zarr").sel(
    channel=slice(0, MAX_CHANNELS)
)
num_levels = len(ts_dt)
time_da = _extract_ds(ts_dt, 0, 0)["time"]
channels = ts_dt["0"].ds["channel"].values
num_channels = len(channels)
data = ts_dt["0"].ds["data"].values.T
range_stream = hv.streams.RangeXY()
size_stream = hv.streams.PlotSize()
dmap = hv.DynamicMap(rescale, streams=[size_stream, range_stream])

y_positions = range(num_channels)
yticks = [(i, ich) for i, ich in enumerate(channels)]
z_data = zscore(data.T, axis=1)

minimap = rasterize(
    hv.Image((time_da, y_positions, z_data), ["Time (s)", "Channel"], "Amplitude (uV)")
)
minimap = minimap.opts(
    cmap="RdBu_r",
    xlabel="",
    yticks=[yticks[0], yticks[-1]],
    toolbar="disable",
    height=120,
    responsive=True,
    clim=(-z_data.std() * CLIM_MUL, z_data.std() * CLIM_MUL),
)
tool_link = RangeToolLink(
    minimap,
    dmap,
    axes=["x", "y"],
    boundsx=(0, time_da.max().item() // 2),
    boundsy=(0, len(channels)),
)
pn.template.FastListTemplate(main=[(dmap + minimap).cols(1)]).show()

Note that the run time for computing the slices is short.

So I think the slow part is rendering, partly because it has to render 90-384 curves (channels) even when zoomed in on only a couple of channels(?). If holoviz/holoviews#6136 is fixed, we can hide the other channels while zoomed in on specific ones.
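One way to act on this is to build curves only for the channels inside the current `y_range`, plus a small margin. This sketch assumes channel `i` is drawn centered at y = i (matching the `y_positions = range(num_channels)` layout used for the minimap); `visible_channels` is a hypothetical helper, and the result could replace `channels` in the loop inside `rescale`.

```python
import math


def visible_channels(channels, y_range, margin=1):
    """Return the channels whose assumed y offset (their index) lies in y_range.

    margin extends the range by that many channel slots on each side so that
    partially visible traces near the edges are still drawn.
    """
    if y_range is None:
        return list(channels)
    y0, y1 = y_range
    first = max(0, math.floor(y0) - margin)
    last = min(len(channels) - 1, math.ceil(y1) + margin)
    return list(channels[first : last + 1])


channels = list(range(384))
print(visible_channels(channels, (10.2, 13.7)))  # [9, 10, 11, 12, 13, 14, 15]
```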

Also, on initialization the DynamicMap callback executes three times before settling. I think this is because it is triggered:

  1. on initialization, without width/height
  2. by the size stream
  3. by the range_stream

I tried caching the curves

    if zoom_level == pn.state.cache.get("current_zoom_level") and pn.state.cache.get(
        "curves"
    ):
        cached_x_range = pn.state.cache["x_range"]
        if x_range[0] >= cached_x_range[0] and x_range[1] <= cached_x_range[1]:
            print(f"Using cached curves! {zoom_level=}")
            if x_range != cached_x_range:
                print(f"Different x_range: {x_range} {cached_x_range}")
            return pn.state.cache["curves"].opts(title=title)

However, it doesn't help much, because I think the time is all spent on the rendering side.
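The caching condition above (same zoom level, and the requested range lies inside the cached range) can be isolated into a small helper, which makes it easier to check separately from rendering. `RangeCache` is a hypothetical class; it can hold any object, such as an overlay of curves.

```python
class RangeCache:
    """Cache one rendered result keyed by zoom level and the x-range it covers."""

    def __init__(self):
        self.zoom_level = None
        self.x_range = None
        self.value = None

    def get(self, zoom_level, x_range):
        """Return the cached value if it covers x_range at the same zoom level, else None."""
        if self.value is None or zoom_level != self.zoom_level:
            return None
        lo, hi = self.x_range
        if x_range[0] >= lo and x_range[1] <= hi:
            return self.value
        return None

    def put(self, zoom_level, x_range, value):
        self.zoom_level = zoom_level
        self.x_range = tuple(x_range)
        self.value = value


cache = RangeCache()
cache.put(2, (0.0, 10.0), "curves")
print(cache.get(2, (1.0, 5.0)))  # curves
```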

@droumis
Collaborator Author

droumis commented Mar 12, 2024

if sub_factors[-1] != sub_factors[-1]:

Nice work, Andrew! I think the above line is probably supposed to be something like if sub_factors[-1] != all_factors[-1]: right?

factor += 1

Is this to avoid division by zero and/or another reason?
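For reference, a quick check of what `factor += 1` does to the level lengths, using a hypothetical 3,000,000-sample series and hypothetical power-of-two factors: each level's length becomes `len // (factor + 1)`, so even the factor-1 level is halved. Since the factors are powers of two (at least 1), division by zero would not occur without the increment either. Note also that the `apply_downsample` shown earlier in the thread skips downsampling only when `n_out` equals the full series length, which cannot happen with the increment.

```python
n_samples = 3_000_000  # hypothetical series length
factors = [1, 8, 64, 512, 1024]  # hypothetical power-of-two factors

# Level lengths as computed in pyramid_downsample (factor incremented first)
lengths_plus_one = [n_samples // (f + 1) for f in factors]
# Level lengths if the factor were used as-is
lengths_plain = [n_samples // f for f in factors]

print(lengths_plus_one)  # [1500000, 333333, 46153, 5847, 2926]
print(lengths_plain)     # [3000000, 375000, 46875, 5859, 2929]
```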

it executes the DynamicMap three times before settling

@philippjfr, is there any way to batch the initial DynamicMap executions so that updates from multiple streams are combined?

@droumis
Collaborator Author

droumis commented Mar 12, 2024

For a limited number of channels, this is pretty good! (Note, there's a range issue with the first channel, which I thought had been resolved already, but hopefully it will be addressed with the work that Simon and Maxime are doing soon for subcoords).

Screen.Recording.2024-03-12.at.2.46.53.PM.mov

However, once the number of channels exceeds ~20 for this particular dataset (I tested 10, 20, 30, 40, 50), the zoom level does not update when zooming in, at least not within a couple of minutes. For larger channel counts, I'm also seeing the DynamicMap get executed several times for each minimap range adjustment, which is definitely hurting performance (we may need to try throttling the callback?). In fact, sometimes it never stops executing and seems to get stuck in a loop.

Here's the code, bringing together what Andrew has done above:

Code
# First, download a dataset, e.g.:
# (3GB) datasets.holoviz.org/ephys_sim/v1/ephys_sim_neuropixels_10s_384ch.h5

from pathlib import Path
import h5py
import numpy as np
import xarray as xr
import dask.array as da
import datatree as dt
from tsdownsample import MinMaxLTTBDownsampler


def _help_downsample(data, time, n_out):
    print(f"{len(time)} -> {n_out} samples")
    indices = MinMaxLTTBDownsampler().downsample(time, data, n_out=n_out)
    return data[indices], indices


def apply_downsample(ts_da, n_out):
    times = ts_da["time"]
    if n_out != len(ts_da["time"]):
        print("Downsampling")
        ts_ds_downsampled, time_indices = xr.apply_ufunc(
            _help_downsample,
            ts_da,
            times,
            kwargs=dict(n_out=n_out),
            input_core_dims=[["time"], ["time"]],
            output_core_dims=[["time"], ["indices"]],
            exclude_dims=set(("time",)),
            vectorize=False,
            dask="parallelized",
            dask_gufunc_kwargs=dict(output_sizes={"time": n_out, "indices": n_out}),
        )
        ts_da_downsampled = ts_ds_downsampled.to_dataarray("channel")
        times_subset = times.values[time_indices.to_array().values]
    else:
        print("No downsampling")
        ts_da_downsampled = ts_da.to_dataarray("channel")
        times_subset = np.vstack(
            [times.values for _ in range(ts_da_downsampled["channel"].size)]
        )
    return xr.DataArray(
        ts_da_downsampled.values,
        dims=["channel", "time"],
        coords={"multi_time": (("channel", "time"), times_subset)},
        name="data",
    )


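def build_dataset(f, data_key, dims, use_dask=True):
    """Wrap an HDF5 (time, channel) array as an xarray Dataset.

    dims maps dimension names to the HDF5 paths of their coordinates; if
    "channel" is missing, the channel dimension gets no coordinate.
    """
    if data_key not in f:
        raise KeyError(f"{data_key} not found in file.")

    for key, value in dims.items():
        if value not in f:
            raise KeyError(f"Dimension '{value}' for key '{key}' not found in file.")

    # Read coordinates eagerly; they are small compared to the data
    coords = {key: f[value][()] for key, value in dims.items()}
    data = f[data_key]
    # Build the dimension-name list without mutating the caller's dict
    dim_names = list(dims)
    if "channel" not in dim_names:
        dim_names.append("channel")
    if use_dask:
        # Keep time in a single chunk (apply_ufunc core dims require it), 5 channels per chunk
        array = da.from_array(data, name="data", chunks=(data.shape[0], 5))
    else:
        array = data
    ds = xr.DataArray(
        array,
        dims=dim_names,
        coords=coords,
        name="data"
    ).to_dataset()
    return ds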


# adapted from https://github.com/carbonplan/ndpyramid/blob/main/ndpyramid/core.py
def multiscales_template(
    *,
    datasets: list = None,
    type: str = "",
    method: str = "",
    version: str = "",
    args: list = None,
    kwargs: dict = None,
):
    if datasets is None:
        datasets = []
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    # https://forum.image.sc/t/multiscale-arrays-v0-1/37930
    return [
        {
            "datasets": datasets,
            "type": type,
            "metadata": {
                "method": method,
                "version": version,
                "args": args,
                "kwargs": kwargs,
            },
        }
    ]


def pyramid_downsample(ds: xr.Dataset, *, factors: list[int], **kwargs) -> dt.DataTree:
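    """Create a multiscale pyramid by MinMaxLTTB-downsampling a dataset by the given factors

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to downsample; must contain a (time, channel) "data" variable.
    factors : list[int]
        The downsampling factors, one per pyramid level. Level ``i`` keeps
        ``len(time) // (factors[i] + 1)`` samples per channel.
    kwargs : dict
        Currently unused; only recorded in the "multiscales" metadata.

    Returns
    -------
    datatree.DataTree
        One child node per level ("0", "1", ...), with the multiscales
        metadata stored as attributes of the root node.
    """
    # Split channels into separate variables so each one is downsampled independently
    ds = ds["data"].to_dataset("channel")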
    factors = np.array(factors).astype(int).tolist()

    # multiscales spec
    save_kwargs = locals()
    del save_kwargs["ds"]

    attrs = {
        "multiscales": multiscales_template(
            datasets=[{"path": str(i)} for i in range(len(factors))],
            type="pick",
            method="pyramid_downsample",
            version="0.1",
            kwargs=save_kwargs,
        )
    }

    # set up pyramid
    plevels = {}

    # pyramid data
    for key, factor in enumerate(factors):
        factor += 1
        result = apply_downsample(ds, len(ds["time"]) // factor)
        plevels[str(key)] = result

    plevels["/"] = xr.Dataset(attrs=attrs)
    return dt.DataTree.from_dict(plevels)


# %% Create time pyramid dataset

PYRAMID_FILE = 'pyramid_neuropix_10s.zarr'
CREATE_PYRAMID = True
if CREATE_PYRAMID:
    
    # f = h5py.File(Path.home() / Path("allensdk_cache/session_715093703/probe_810755797_lfp.nwb"), "r")

    # ts_ds = build_dataset(
    #     f,
    #     "acquisition/probe_810755797_lfp_data/data",
    #     {
    #         "time": "acquisition/probe_810755797_lfp_data/timestamps",
    #         "channel": "acquisition/probe_810755797_lfp_data/electrodes",
    #     },
    # ).isel(channel=list(np.arange(0,4)))

    f = h5py.File(Path.home() / Path("data/ephys_sim_neuropixels/ephys_sim_neuropixels_10s_384ch.h5"), "r")
    ts_ds = build_dataset(
        f,
        "/recordings",
        {
            "time": "timestamps",
            # Channels have no names in this dataset, but they have other metadata such as location on the probe.
            # Exclude the channel dim so that incrementing dummy labels are used instead.
        },
    )

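    data_length = ts_ds["time"].size
    screen_width = 1500  # in px
    num_factors = 4  # must be >= 2

    # Largest power-of-two factor with data_length / factor >= screen_width
    target_factor = data_length / screen_width
    max_zoom = max(0, int(np.log2(target_factor)))
    all_factors = 2 ** np.arange(max_zoom + 1)
    # Keep roughly num_factors evenly spaced levels; the slice step must be at least 1
    sub_factors = all_factors[:: max(1, max_zoom // (num_factors - 1))]
    # Always include the coarsest level
    if sub_factors[-1] != all_factors[-1]:
        sub_factors = np.append(sub_factors, all_factors[-1])
    ts_dt = pyramid_downsample(ts_ds, factors=sub_factors)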
    ts_dt.to_zarr(PYRAMID_FILE, mode="w")
ts_dt = dt.open_datatree(PYRAMID_FILE, engine="zarr")


# %% Plot 
import numpy as np
import panel as pn
import datatree as dt
import holoviews as hv
from scipy.stats import zscore
from holoviews.plotting.links import RangeToolLink
from holoviews.operation.datashader import rasterize
from bokeh.models.tools import WheelZoomTool, HoverTool

hv.extension("bokeh")


CLIM_MUL = 0.5  # minimap color limits, in multiples of the z-scored data's std
MAX_CHANNELS = 20
X_PADDING = 0.2  # fraction of the visible x-range also loaded on each side


def _extract_ds(ts_dt, level, channel):
    ds = (
        ts_dt[str(level)]
        .sel(channel=channel)
        .ds.swap_dims({"time": "multi_time"})
        .rename({"multi_time": "time"})
    )
    return ds


def rescale(x_range, y_range, width, scale, height):
    import time

    s = time.perf_counter()
    print(f"- Update triggered! {width=} {x_range=}")
    if x_range is None:
        x_range = time_da.min().item(), time_da.max().item()
    if y_range is None:
        y_range = 0, num_channels
    x_padding = (x_range[1] - x_range[0]) * X_PADDING
    time_slice = slice(x_range[0] - x_padding, x_range[1] + x_padding)

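    if width is None or height is None:
        # Plot size not known yet: start from the coarsest level
        zoom_level = num_levels - 1
        size = _extract_ds(ts_dt, zoom_level, 0)["time"].sel(time=time_slice).size
    else:
        # Pick the level whose in-view sample count is closest to the plot width in px
        sizes = [
            _extract_ds(ts_dt, zoom_level, 0)["time"].sel(time=time_slice).size
            for zoom_level in range(num_levels)
        ]
        zoom_level = np.argmin(np.abs(np.array(sizes) - width))
        size = sizes[zoom_level]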
    e = time.perf_counter()
    print(f"Zoom level computation took {e-s:.2f}s")

    title = (
        f"level {zoom_level} ({x_range[0]:.2f}s - {x_range[1]:.2f}s) "
        f"(WxH: {width}x{height}) (length: {size})"
    )
    # if zoom_level == pn.state.cache.get("current_zoom_level") and pn.state.cache.get(
    #     "curves"
    # ):
    #     cached_x_range = pn.state.cache["x_range"]
    #     if x_range[0] >= cached_x_range[0] and x_range[1] <= cached_x_range[1]:
    #         print(f"Using cached curves! {zoom_level=}")
    #         if x_range != cached_x_range:
    #             print(f"Different x_range: {x_range} {cached_x_range}")
    #         return pn.state.cache["curves"].opts(title=title)

    curves = hv.Overlay(kdims="Channel")
    for channel in channels:
        hover = HoverTool(
            tooltips=[
                ("Channel", str(channel)),
                ("Time", "$x s"),
                ("Amplitude", "$y µV"),
            ]
        )
        sub_ds = _extract_ds(ts_dt, zoom_level, channel).sel(time=time_slice).load()
        curve = hv.Curve(sub_ds, ["time"], ["data"], label=f"ch{channel}").opts(
            color="black",
            line_width=1,
            subcoordinate_y=True,
            subcoordinate_scale=1,
            default_tools=["pan", "reset", WheelZoomTool(), hover],
        )
        curves *= curve
    print(f"Overlaying curves took {time.perf_counter()-e:.2f}s")

    curves = curves.opts(
        xlabel="Time (s)",
        ylabel="Channel",
        title=title,
        show_legend=False,
        padding=0,
        aspect=1.5,
        responsive=True,
        framewise=True,
        axiswise=True,
    )
    pn.state.cache["current_zoom_level"] = zoom_level
    pn.state.cache["x_range"] = x_range
    pn.state.cache["curves"] = curves
    print(f"Using updated curves! {x_range} {zoom_level}\n\n")
    return curves

ts_dt = dt.open_datatree(PYRAMID_FILE, engine="zarr").sel(
    channel=slice(0, MAX_CHANNELS)
)
num_levels = len(ts_dt)
time_da = _extract_ds(ts_dt, 0, 0)["time"]
channels = ts_dt["0"].ds["channel"].values
num_channels = len(channels)
data = ts_dt["0"].ds["data"].values.T
range_stream = hv.streams.RangeXY()
size_stream = hv.streams.PlotSize()
dmap = hv.DynamicMap(rescale, streams=[size_stream, range_stream])

y_positions = range(num_channels)
yticks = [(i, ich) for i, ich in enumerate(channels)]
z_data = zscore(data.T, axis=1)

minimap = rasterize(
    hv.Image((time_da, y_positions, z_data), ["Time (s)", "Channel"], "Amplitude (µV)")
)
minimap = minimap.opts(
    cmap="RdBu_r",
    xlabel="",
    yticks=[yticks[0], yticks[-1]],
    toolbar="disable",
    height=120,
    responsive=True,
    clim=(-z_data.std() * CLIM_MUL, z_data.std() * CLIM_MUL),
)
tool_link = RangeToolLink(
    minimap,
    dmap,
    axes=["x", "y"],
    boundsx=(0, time_da.max().item() // 2),
    boundsy=(0, len(channels)),
)
pn.template.FastListTemplate(main=[(dmap + minimap).cols(1)]).servable()
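Why `apply_downsample` returns a 2D `multi_time` coordinate can be seen with a small NumPy sketch: a value-aware picker keeps different sample indices for each channel, so each channel ends up with its own timestamps. The `minmax_indices` helper is hypothetical, a simpler min-max bucket picker standing in for tsdownsample's `MinMaxLTTBDownsampler`, and the random data is synthetic.

```python
import numpy as np


def minmax_indices(y, n_out):
    """Hypothetical min-max bucket picker (a simplified stand-in for MinMaxLTTB).

    Splits y into n_out // 2 equal-width buckets and keeps the positions of the
    minimum and maximum of each bucket. Assumes n_out <= len(y).
    """
    n_buckets = n_out // 2
    edges = np.linspace(0, len(y), n_buckets + 1).astype(int)
    picks = []
    for start, stop in zip(edges[:-1], edges[1:]):
        segment = y[start:stop]
        picks.append(start + int(np.argmin(segment)))
        picks.append(start + int(np.argmax(segment)))
    return np.sort(np.array(picks))


rng = np.random.default_rng(0)
times = np.linspace(0.0, 1.0, 1000)    # shared full-resolution timestamps
data = rng.standard_normal((3, 1000))  # synthetic (channel, time) signal

# One row of picked indices per channel: shape (channel, n_out)
idx = np.stack([minmax_indices(channel, 100) for channel in data])
# Fancy-indexing the 1D time axis yields a per-channel time coordinate,
# analogous to multi_time in apply_downsample
multi_time = times[idx]
values = np.take_along_axis(data, idx, axis=1)

print(multi_time.shape, values.shape)  # (3, 100) (3, 100)
print(np.array_equal(idx[0], idx[1]))  # channels generally pick different samples
```

Because each channel keeps its own picks, sharing a single time axis across channels only approximates the timestamps of all but one channel.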

@ahuang11
Collaborator

ahuang11 commented Mar 14, 2024

if sub_factors[-1] != all_factors[-1]:

Yes! Thanks for spotting that.
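The factor-selection step from the pyramid script, including the final `np.append` that keeps the coarsest level, can be wrapped in a small function to check its output. `select_factors` is a hypothetical name; the `max(1, ...)` and `max(0, ...)` guards are additions that avoid a zero slice step or a negative zoom on short recordings, and the sample counts are illustrative.

```python
import numpy as np


def select_factors(data_length, screen_width, num_factors):
    """Pick power-of-two downsampling factors for a time pyramid (hypothetical helper).

    Assumes num_factors >= 2.
    """
    # Largest power of two with data_length / factor >= screen_width
    max_zoom = max(0, int(np.log2(data_length / screen_width)))
    all_factors = 2 ** np.arange(max_zoom + 1)
    # Roughly num_factors evenly spaced levels; a step of 0 would raise ValueError
    step = max(1, max_zoom // (num_factors - 1))
    sub_factors = all_factors[::step]
    # Always keep the coarsest factor, even if the stride skipped it
    if sub_factors[-1] != all_factors[-1]:
        sub_factors = np.append(sub_factors, all_factors[-1])
    return sub_factors.tolist()


print(select_factors(300_000, 1500, 4))  # [1, 4, 16, 64, 128]
print(select_factors(3_000, 1500, 4))    # [1, 2]
```

Without the append, the first call would stop at 64 and the coarsest level (128) would be missing from the pyramid.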

factor += 1

I think I initially conflated factor with zoom_level (e.g. on tile maps, zoom level 0 is the least coarse zoom). I think I could refactor to remove that += 1 and add -1 to datasets in datasets=[{"path": str(i)} for i in range(len(factors))], or the easier path would be to simply rename factors to zoom_level.
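To make the off-by-one concrete: level `i` is stored at path `str(i)` and keeps `len(time) // (factor + 1)` samples, so with `factor += 1` even level 0 is already halved. A minimal sketch, assuming a hypothetical 300 000-sample recording and factors `[1, 4, 16, 64, 128]`; `level_lengths` is an illustrative helper, not part of the script.

```python
def level_lengths(n_samples, factors, add_one=False):
    """Samples kept per pyramid level (hypothetical helper).

    add_one=True mimics the `factor += 1` step in pyramid_downsample.
    """
    return {
        str(level): n_samples // (factor + 1 if add_one else factor)
        for level, factor in enumerate(factors)
    }


factors = [1, 4, 16, 64, 128]
print(level_lengths(300_000, factors))
# {'0': 300000, '1': 75000, '2': 18750, '3': 4687, '4': 2343}
print(level_lengths(300_000, factors, add_one=True))
# {'0': 150000, '1': 60000, '2': 17647, '3': 4615, '4': 2325}
```

Dropping the `+= 1` makes level `"0"` the untouched full-resolution copy (factor 1).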

the zoom level does not update quickly enough when zooming in (at least within a couple of minutes).

Performance could be significantly improved if we were able to use y_range, but for now see holoviz/holoviews#6136.
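If y_range could be used, the callback would only need to load the channels in view. A minimal sketch, assuming (hypothetically) that channel `i` is drawn in a band centred on y = i, as in the minimap's `y_positions = range(num_channels)`; `visible_channels` is an illustrative helper, and how `subcoordinate_y` actually maps channels to y-values should be checked against HoloViews.

```python
import math


def visible_channels(y_range, channels):
    """Return the channels whose band [i - 0.5, i + 0.5] overlaps y_range.

    Assumes channel i is centred on y = i (hypothetical layout).
    """
    if y_range is None:
        return list(channels)
    y0, y1 = y_range
    # First and one-past-last channel index whose band touches [y0, y1]
    lo = max(0, math.ceil(y0 - 0.5))
    hi = max(lo, min(len(channels), math.floor(y1 + 0.5) + 1))
    return list(channels[lo:hi])


print(visible_channels((2.3, 5.7), list(range(20))))  # [2, 3, 4, 5, 6]
```

The callback could then loop over `visible_channels(y_range, channels)` instead of all channels.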

Another idea is to bias the streams to use coarser zoom levels so that only a max number of points are shown at a given time, depending on the number of channels, e.g. if initial zoom_level was 4 for 10 channels, use 6 instead for 15 channels, and 8 for 20 channels.
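One way to express this channel-dependent bias is a total point budget: take the finest level whose in-view samples times the channel count stays under a cap, and never go finer than the width-based choice `rescale` already makes. A sketch under assumptions: `pick_zoom_level` and the `max_points` default are hypothetical and would need tuning; levels are ordered finest (0) to coarsest, so `sizes` decreases with level.

```python
import numpy as np


def pick_zoom_level(sizes, width, n_channels, max_points=50_000):
    """Choose a pyramid level for the current viewport (hypothetical helper).

    sizes[i] is the number of samples per channel inside the visible time
    range at level i.
    """
    sizes = np.asarray(sizes)
    # Level whose per-channel sample count is closest to the plot width (px),
    # mirroring the argmin used in rescale()
    width_level = int(np.argmin(np.abs(sizes - width)))
    # Finest level that keeps the total number of rendered points within budget
    within_budget = np.nonzero(sizes * n_channels <= max_points)[0]
    budget_level = int(within_budget[0]) if within_budget.size else len(sizes) - 1
    # The coarser of the two choices satisfies both rules
    return max(width_level, budget_level)


sizes = [8000, 4000, 2000, 1000, 500]
print(pick_zoom_level(sizes, 1500, 10))   # 2
print(pick_zoom_level(sizes, 1500, 40))   # 3
print(pick_zoom_level(sizes, 1500, 200))  # 4
```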

which is definitely harming the performance (may need to try throttling the callback

I don't think it's a matter of throttling. I think it's multiple streams re-triggering; related: "the init execution of DynamicMap to combine updates from multiple streams"

@droumis
Collaborator Author

droumis commented Mar 19, 2024

Another idea is to bias the streams to use coarser zoom levels so that only a max number of points are shown at a given time, depending on the number of channels, e.g. if initial zoom_level was 4 for 10 channels, use 6 instead for 15 channels, and 8 for 20 channels.

I think this is worth trying. As you increase the number of channels, you likely shrink the vertical real estate of each individual channel, so downsampling to a coarser zoom level than the x-range stream alone indicates seems OK in theory. I think it will require some manual tweaking to find the sweet spot.

@droumis
Collaborator Author

droumis commented Mar 26, 2024

@ahuang11
Copy link
Collaborator

ahuang11 commented Mar 26, 2024

Added a PR to ndpyramid, carbonplan/ndpyramid#94, to abstract out some of this functionality so that it isn't duplicated and the code can be simplified to just:

import h5py
import xarray as xr
import dask.array as da
import datatree as dt
from ndpyramid import pyramid_create
from tsdownsample import MinMaxLTTBDownsampler


def _help_downsample(data, time, n_out):
    indices = MinMaxLTTBDownsampler().downsample(time, data, n_out=n_out)
    return data[indices], indices


def apply_downsample(ts_ds, factor, dims):
    dim = dims[0]
    n_out = ts_ds.sizes[dim] // factor
    ts_ds_downsampled, indices = xr.apply_ufunc(
        _help_downsample,
        ts_ds["data"],
        ts_ds[dim],
        kwargs=dict(n_out=n_out),
        input_core_dims=[[dim], [dim]],
        output_core_dims=[[dim], ["indices"]],
        exclude_dims=set((dim,)),
        vectorize=True,
        dask="parallelized",
        dask_gufunc_kwargs=dict(output_sizes={dim: n_out, "indices": n_out}),
    )
    # NOTE: channel 0's picked indices label the time axis for all channels;
    # the other channels' picks can fall on different timestamps
    ts_ds_downsampled[dim] = ts_ds[dim].isel({dim: indices.values[0]})
    return ts_ds_downsampled.rename("data")


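def build_dataset(f, data_key, dims):
    # dims maps dimension names to the HDF5 paths of their coordinate arrays
    coords = {name: f[path][()] for name, path in dims.items()}
    data = f[data_key]
    ds = xr.DataArray(
        da.from_array(data, name="data", chunks=(data.shape[0], 1)),
        dims=list(dims),
        coords=coords,
        name="data",  # to_dataset() needs a named DataArray
    ).to_dataset()
    return ds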


f = h5py.File("allensdk_cache/session_715093703/probe_810755797_lfp.nwb", "r")
ts_ds = build_dataset(
    f,
    "acquisition/probe_810755797_lfp_data/data",
    {
        "time": "acquisition/probe_810755797_lfp_data/timestamps",
        "channel": "acquisition/probe_810755797_lfp_data/electrodes",
    },
).isel(channel=[0, 1, 2, 3, 4])
ts_dt = pyramid_create(
    ts_ds,
    factors=[1, 2],
    dims=["time"],
    func=apply_downsample,
    type_label="pick",
    method_label="pyramid_downsample",
)
ts_dt

@droumis droumis changed the title Multi-(Time)-Scale Generator and Dynamic Accessor [GOAL] Large, Multi-Time-Resolution Generator and Dynamic Accessor Apr 5, 2024
@droumis droumis changed the title [GOAL] Large, Multi-Time-Resolution Generator and Dynamic Accessor [GOAL] Demo viewing of large multi-chan timeseries data with multi-time-resolution generator and dynamic accessor Apr 5, 2024
@droumis droumis removed the status in CZI R5 neuro Apr 5, 2024
@ahuang11
Collaborator

ahuang11 commented Apr 8, 2024

The PR carbonplan/ndpyramid#120 is now merged; awaiting next release.

@droumis
Collaborator Author

droumis commented Jun 23, 2024

Closing, as this demo is captured in the notebook added via PR #96. Great work @ahuang11, @philippjfr and team.

@droumis droumis closed this as completed Jun 23, 2024