Skip to content

Commit

Permalink
Merge pull request #108 from AllenNeuralDynamics/han_upgrade_streamli…
Browse files Browse the repository at this point in the history
…t_1.41

Streamlit 1.41 upgrade and performance enhancements
  • Loading branch information
hanhou authored Jan 16, 2025
2 parents 1a06f45 + 24ca965 commit cc35ead
Show file tree
Hide file tree
Showing 10 changed files with 373 additions and 265 deletions.
258 changes: 132 additions & 126 deletions code/Home.py

Large diffs are not rendered by default.

178 changes: 95 additions & 83 deletions code/pages/0_Data inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from streamlit_plotly_events import plotly_events
import plotly.io as pio
pio.json.config.default_engine = "orjson"

import time
import streamlit_nested_layout
Expand All @@ -22,6 +23,7 @@
)
from util.reformat import formatting_metadata_df
from util.aws_s3 import load_raw_sessions_on_VAST
from util.settings import override_plotly_theme
from Home import init


Expand All @@ -48,15 +50,31 @@
)

# Load QUERY_PRESET from json
with open("data_inventory_QUERY_PRESET.json", "r") as f:
QUERY_PRESET = json.load(f)
@st.cache_data()
def load_presets():
with open("data_inventory_QUERY_PRESET.json", "r") as f:
QUERY_PRESET = json.load(f)

with open("data_inventory_VENN_PRESET.json", "r") as f:
VENN_PRESET = json.load(f)
return QUERY_PRESET, VENN_PRESET

QUERY_PRESET, VENN_PRESET = load_presets()

META_COLUMNS = [
"Han_temp_pipeline (bpod)",
"Han_temp_pipeline (bonsai)",
"VAST_raw_data_on_VAST",
] + [query["alias"] for query in QUERY_PRESET]

X_BIN_SIZE_MAPPER = { # For plotly histogram xbins
"Daily": 1000*3600*24, # Milliseconds
"Weekly": 1000*3600*24*7, # Milliseconds
"Monthly": "M1",
"Quarterly": "M4",
}


@st.cache_data(ttl=3600*12)
def merge_queried_dfs(dfs, queries_to_merge):
# Combine queried dfs using df_unique_mouse_date (on index "subject_id", "session_date" only)
Expand Down Expand Up @@ -235,7 +253,7 @@ def count_true_values(df, time_period, column):
rows=len(columns),
cols=1,
shared_xaxes=True,
vertical_spacing=0.05,
vertical_spacing=0.1,
subplot_titles=columns,
)

Expand All @@ -261,7 +279,7 @@ def count_true_values(df, time_period, column):

# Updating layout
fig.update_layout(
height=200 * len(columns),
height=250 * len(columns),
showlegend=False,
title=f"{time_period} counts",
)
Expand All @@ -270,7 +288,7 @@ def count_true_values(df, time_period, column):
for i, column in enumerate(columns):
fig.add_trace(go.Histogram(
x=df[df[column]==True]["session_date"],
xbins=dict(size="M1"), # Only monthly bins look good
xbins=dict(size=X_BIN_SIZE_MAPPER[time_period]), # Only monthly bins look good
name=column,
marker_color=colors[i],
opacity=0.75
Expand All @@ -281,18 +299,18 @@ def count_true_values(df, time_period, column):
height=500,
bargap=0.05, # Gap between bars of adjacent locations
bargroupgap=0.1, # Gap between bars of the same location
barmode='group', # Grouped style
barmode="group", # Grouped style
showlegend=True,
title="Monthly counts",
legend=dict(
orientation="h", # Horizontal legend
y=-0.2, # Position below the plot
x=0.5, # Center the legend
xanchor="center", # Anchor the legend's x position
yanchor="top" # Anchor the legend's y position
yanchor="top", # Anchor the legend's y position
),
title="Monthly counts"
)

return fig

def app():
Expand Down Expand Up @@ -413,88 +431,82 @@ def app():
)

# --- Venn diagram from presets ---
with open("data_inventory_VENN_PRESET.json", "r") as f:
VENN_PRESET = json.load(f)

if VENN_PRESET:
add_venn_diagrms(df_merged)

@st.fragment
def add_venn_diagrms(df_merged):

cols = st.columns([2, 1])
cols[0].markdown("## Venn diagrams from presets")
with cols[1].expander("Time view settings", expanded=True):
cols_1 = st.columns([1, 1])
if_separate_plots = cols_1[0].checkbox("Separate in subplots", value=True)
if_sync_y_limits = cols_1[0].checkbox(
"Sync Y limits", value=True, disabled=not if_separate_plots
)
time_period = cols_1[1].selectbox(
"Bin size",
["Daily", "Weekly", "Monthly", "Quarterly"],
index=1,
)

cols = st.columns([2, 1])
cols[0].markdown("## Venn diagrams from presets")
with cols[1].expander("Time view settings", expanded=True):
cols_1 = st.columns([1, 1])
if_separate_plots = cols_1[0].checkbox("Separate in subplots", value=True)
if_sync_y_limits = cols_1[0].checkbox(
"Sync Y limits", value=True, disabled=not if_separate_plots
for i_venn, venn_preset in enumerate(VENN_PRESET):
# -- Venn diagrams --
st.markdown(f"### ({i_venn+1}). {venn_preset['name']}")
fig, notes = generate_venn(
df_merged,
venn_preset
)
time_period = cols_1[1].selectbox(
"Bin size",
["Daily", "Weekly", "Monthly", "Quarterly"],
index=1,
disabled=not if_separate_plots,
for note in notes:
st.markdown(note)

cols = st.columns([1, 1])
with cols[0]:
st.pyplot(fig, use_container_width=True)

# -- Show and download df for this Venn --
circle_columns = [c_s["column"] for c_s in venn_preset["circle_settings"]]
# Show histogram over time for the columns and patches in preset
df_this_preset = df_merged[circle_columns]
# Filter out rows that have at least one True in this Venn
df_this_preset = df_this_preset[df_this_preset.any(axis=1)]

# Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"]
for patch_setting in venn_preset.get("patch_settings", []):
idx = _filter_df_by_patch_ids(
df_this_preset[circle_columns],
patch_setting["patch_ids"]
)
df_this_preset.loc[idx, str(patch_setting["patch_ids"])] = True

for i_venn, venn_preset in enumerate(VENN_PRESET):
# -- Venn diagrams --
st.markdown(f"### ({i_venn+1}). {venn_preset['name']}")
fig, notes = generate_venn(
df_merged,
venn_preset
)
for note in notes:
st.markdown(note)

cols = st.columns([1, 1])
with cols[0]:
st.pyplot(fig, use_container_width=True)

# -- Show and download df for this Venn --
circle_columns = [c_s["column"] for c_s in venn_preset["circle_settings"]]
# Show histogram over time for the columns and patches in preset
df_this_preset = df_merged[circle_columns]
# Filter out rows that have at least one True in this Venn
df_this_preset = df_this_preset[df_this_preset.any(axis=1)]

# Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"]
for patch_setting in venn_preset.get("patch_settings", []):
idx = _filter_df_by_patch_ids(
df_this_preset[circle_columns],
patch_setting["patch_ids"]
)
df_this_preset.loc[idx, str(patch_setting["patch_ids"])] = True
# Join in other extra columns
df_this_preset = df_this_preset.join(
df_merged[[col for col in df_merged.columns if col not in META_COLUMNS]], how="left"
)

# Join in other extra columns
df_this_preset = df_this_preset.join(
df_merged[[col for col in df_merged.columns if col not in META_COLUMNS]], how="left"
with cols[0]:
download_df(
df_this_preset,
label="Download as CSV for this Venn diagram",
file_name=f"df_{venn_preset['name']}.csv",
)
with st.expander(f"Show dataframe, n = {len(df_this_preset)}"):
st.write(df_this_preset)

with cols[0]:
download_df(
df_this_preset,
label="Download as CSV for this Venn diagram",
file_name=f"df_{venn_preset['name']}.csv",
)
with st.expander(f"Show dataframe, n = {len(df_this_preset)}"):
st.write(df_this_preset)

with cols[1]:
# -- Show histogram over time --
fig = plot_histogram_over_time(
df=df_this_preset.reset_index(),
venn_preset=venn_preset,
time_period=time_period,
if_sync_y_limits=if_sync_y_limits,
if_separate_plots=if_separate_plots,
)
plotly_events(
fig,
click_event=False,
hover_event=False,
select_event=False,
override_height=fig.layout.height * 1.1,
override_width=fig.layout.width,
)
with cols[1]:
# -- Show histogram over time --
fig = plot_histogram_over_time(
df=df_this_preset.reset_index(),
venn_preset=venn_preset,
time_period=time_period,
if_sync_y_limits=if_sync_y_limits,
if_separate_plots=if_separate_plots,
)
override_plotly_theme(fig, font_size_scale=0.9)
st.plotly_chart(fig, use_container_width=True)

st.markdown("---")
st.markdown("---")

# --- User-defined Venn diagram ---
# Multiselect for selecting queries up to three
Expand Down
1 change: 0 additions & 1 deletion code/pages/1_Basic behavior analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from streamlit_plotly_events import plotly_events
from util.aws_s3 import load_data
from util.streamlit import add_session_filter, data_selector, add_footnote
from scipy.stats import gaussian_kde
Expand Down
2 changes: 1 addition & 1 deletion code/util/aws_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_pat
_f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
output_format='PNG',
caption=f_name.split('/')[-1] if caption and f_name else '',
use_column_width='always',
use_container_width='always',
**kwargs)

return img
Expand Down
2 changes: 2 additions & 0 deletions code/util/foraging_plotly.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
pio.json.config.default_engine = "orjson"


def moving_average(a, n=3) :
Expand Down
3 changes: 3 additions & 0 deletions code/util/plot_autotrain_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
pio.json.config.default_engine = "orjson"

import streamlit as st
from aind_auto_train.plot.curriculum import get_stage_color_mapper
from aind_auto_train.schema.curriculum import TrainingStage
Expand Down
85 changes: 82 additions & 3 deletions code/util/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import plotly.io as pio
pio.json.config.default_engine = "orjson"

# Setting up layout for each session
draw_type_layout_definition = [
[1], # columns in the first row
Expand Down Expand Up @@ -30,6 +33,82 @@
}

# For quick preview
draw_types_quick_preview = [
'1. Choice history',
'3. Logistic regression (Su2022)']
draw_types_quick_preview = ["1. Choice history", "3. Logistic regression (Su2022)"]


# For plotly styling
PLOTLY_FIG_DEFAULT = dict(
font_family="Arial",
legend_font_color='black',
)
PLOTLY_AXIS_DEFAULT = dict(
showline=True,
linewidth=2,
linecolor="black",
showgrid=True,
gridcolor="lightgray",
griddash="solid",
minor_showgrid=False,
minor_gridcolor="lightgray",
minor_griddash="solid",
zeroline=True,
ticks="outside",
tickcolor="black",
ticklen=7,
tickwidth=2,
ticksuffix=" ",
tickfont=dict(
family="Arial",
color="black",
),
)

def override_plotly_theme(
fig,
theme="simple_white",
fig_specs=PLOTLY_FIG_DEFAULT,
axis_specs=PLOTLY_AXIS_DEFAULT,
font_size_scale=1.0,
):
"""
Fix the problem that simply using fig.update_layout(template=theme) doesn't work with st.plotly_chart.
I have to use update_layout to explicitly set the theme.
"""

dict_plotly_template = pio.templates[theme].layout.to_plotly_json()
fig.update_layout(**dict_plotly_template) # First apply the plotly official theme

# Apply settings to all x-axes
for axis in fig.layout:
if axis.startswith('xaxis') or axis.startswith('yaxis'):
fig.layout[axis].update(axis_specs)
fig.layout[axis].update(
tickfont_size=22 * font_size_scale,
title_font_size=22 * font_size_scale,
)
if axis.startswith("yaxis"):
fig.layout[axis].update(title_standoff=10 * font_size_scale)

fig.update_layout(**fig_specs) # Apply settings to the entire figure

# Customize the font of subplot titles
for annotation in fig['layout']['annotations']:
annotation['font'] = dict(
family="Arial", # Font family
size=20 * font_size_scale, # Font size
color="black" # Font color
)

# Figure-level settings
fig.update_layout(
font_size=22 * font_size_scale,
hoverlabel_font_size=17 * font_size_scale,
legend_font_size=17 * font_size_scale,
margin=dict(
l=130 * font_size_scale,
r=50 * font_size_scale,
b=130 * font_size_scale,
t=100 * font_size_scale,
),
)
return
Loading

0 comments on commit cc35ead

Please sign in to comment.