From d909b499d8f8556ac54982021f595b26e15698c6 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Thu, 9 Jan 2025 22:39:15 +0000 Subject: [PATCH 01/19] enable all xbins for plotly histogram --- code/pages/0_Data inventory.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/code/pages/0_Data inventory.py b/code/pages/0_Data inventory.py index cd07753..66a726e 100644 --- a/code/pages/0_Data inventory.py +++ b/code/pages/0_Data inventory.py @@ -57,6 +57,14 @@ "VAST_raw_data_on_VAST", ] + [query["alias"] for query in QUERY_PRESET] +X_BIN_SIZE_MAPPER = { # For plotly histogram xbins + "Daily": 1000*3600*24, # Milliseconds + "Weekly": 1000*3600*24*7, # Milliseconds + "Monthly": "M1", + "Quarterly": "M4", +} + + @st.cache_data(ttl=3600*12) def merge_queried_dfs(dfs, queries_to_merge): # Combine queried dfs using df_unique_mouse_date (on index "subject_id", "session_date" only) @@ -270,7 +278,7 @@ def count_true_values(df, time_period, column): for i, column in enumerate(columns): fig.add_trace(go.Histogram( x=df[df[column]==True]["session_date"], - xbins=dict(size="M1"), # Only monthly bins look good + xbins=dict(size=X_BIN_SIZE_MAPPER[time_period]), # Bin size follows the selected time period name=column, marker_color=colors[i], opacity=0.75 @@ -430,7 +438,6 @@ def app(): "Bin size", ["Daily", "Weekly", "Monthly", "Quarterly"], index=1, - disabled=not if_separate_plots, ) for i_venn, venn_preset in enumerate(VENN_PRESET): From 77d758a3ad5b4772efc926f29030569a59d586dc Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Thu, 9 Jan 2025 22:41:31 +0000 Subject: [PATCH 02/19] caching presets --- code/pages/0_Data inventory.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/code/pages/0_Data inventory.py b/code/pages/0_Data inventory.py index 66a726e..a14e201 100644 --- a/code/pages/0_Data inventory.py +++ b/code/pages/0_Data inventory.py @@ -48,8 +48,16 @@ ) # Load QUERY_PRESET from json -with 
open("data_inventory_QUERY_PRESET.json", "r") as f: - QUERY_PRESET = json.load(f) +@st.cache_data() +def load_presets(): + with open("data_inventory_QUERY_PRESET.json", "r") as f: + QUERY_PRESET = json.load(f) + + with open("data_inventory_VENN_PRESET.json", "r") as f: + VENN_PRESET = json.load(f) + return QUERY_PRESET, VENN_PRESET + +QUERY_PRESET, VENN_PRESET = load_presets() META_COLUMNS = [ "Han_temp_pipeline (bpod)", @@ -421,9 +429,6 @@ def app(): ) # --- Venn diagram from presets --- - with open("data_inventory_VENN_PRESET.json", "r") as f: - VENN_PRESET = json.load(f) - if VENN_PRESET: cols = st.columns([2, 1]) From b24040c89b22e8ec68ecdf42060251ed93d8b3fa Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Thu, 9 Jan 2025 22:48:33 +0000 Subject: [PATCH 03/19] change use_container_width --- code/util/aws_s3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code/util/aws_s3.py b/code/util/aws_s3.py index 300dffc..9c69e32 100644 --- a/code/util/aws_s3.py +++ b/code/util/aws_s3.py @@ -94,7 +94,7 @@ def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_pat _f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png", output_format='PNG', caption=f_name.split('/')[-1] if caption and f_name else '', - use_column_width='always', + use_container_width='always', **kwargs) return img From 6263024e29baf8ef2b69b9bc66b8e17a3d344efc Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 10 Jan 2025 00:01:06 +0000 Subject: [PATCH 04/19] bump versions --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index bcb790a..a617adb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -streamlit==1.31.0 +streamlit==1.41.1 streamlit-aggrid==0.3.5 streamlit-bokeh3-events==0.1.4 streamlit_dynamic_filters==0.1.9 streamlit-nested-layout==0.1.1 streamlit-plotly-events==0.0.6 -pygwalker==0.4.7 
-extra-streamlit-components==0.1.56 +pygwalker==0.4.9.13 +extra-streamlit-components==0.1.71 numpy==1.26.4 pandas==2.2.2 matplotlib==3.9.2 From 9e2e161408b9b2c46d2ca847b3d602b51f4a4f0c Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 10 Jan 2025 00:15:59 +0000 Subject: [PATCH 05/19] Use fragment in the main page --- code/Home.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/code/Home.py b/code/Home.py index f4c7646..8db30d1 100644 --- a/code/Home.py +++ b/code/Home.py @@ -585,9 +585,12 @@ def app(): if len(st.session_state.df_session_filtered) == 0: st.markdown('## No filtered results!') return - - aggrid_outputs = aggrid_interactive_table_session(df=st.session_state.df_session_filtered, table_height=table_height) - + + aggrid_outputs = aggrid_interactive_table_session( + df=st.session_state.df_session_filtered, + table_height=table_height, + ) + if len(aggrid_outputs['selected_rows']) and not set(pd.DataFrame(aggrid_outputs['selected_rows'] ).set_index(['h2o', 'session']).index ) == set(st.session_state.df_selected_from_dataframe.set_index(['h2o', 'session']).index): @@ -596,6 +599,10 @@ def app(): # if st.session_state.tab_id == "tab_session_x_y": st.rerun() + add_tabs() + +@st.fragment +def add_tabs(): chosen_id = stx.tab_bar(data=[ stx.TabBarItemData(id="tab_auto_train_history", title="🎓 Automatic Training History", description="Track progress"), stx.TabBarItemData(id="tab_session_inspector", title="👀 Session Inspector", description="Select sessions from the table and show plots"), @@ -658,7 +665,7 @@ def app(): spec="./gw_config.json", ) - pygwalker_renderer.render_explore(height=1010, scrolling=False) + pygwalker_renderer.render_explore() elif chosen_id == "tab_session_inspector": with placeholder: From f883477a179c0ff5c2e526c42e73834ce056b4a0 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 10 Jan 2025 00:20:00 +0000 Subject: [PATCH 06/19] use fragment in Data inventory --- code/pages/0_Data 
inventory.py | 142 +++++++++++++++++---------------- 1 file changed, 73 insertions(+), 69 deletions(-) diff --git a/code/pages/0_Data inventory.py b/code/pages/0_Data inventory.py index a14e201..71c5722 100644 --- a/code/pages/0_Data inventory.py +++ b/code/pages/0_Data inventory.py @@ -430,83 +430,87 @@ def app(): # --- Venn diagram from presets --- if VENN_PRESET: + add_venn_diagrms(df_merged) + +@st.fragment +def add_venn_diagrms(df_merged): + + cols = st.columns([2, 1]) + cols[0].markdown("## Venn diagrams from presets") + with cols[1].expander("Time view settings", expanded=True): + cols_1 = st.columns([1, 1]) + if_separate_plots = cols_1[0].checkbox("Separate in subplots", value=True) + if_sync_y_limits = cols_1[0].checkbox( + "Sync Y limits", value=True, disabled=not if_separate_plots + ) + time_period = cols_1[1].selectbox( + "Bin size", + ["Daily", "Weekly", "Monthly", "Quarterly"], + index=1, + ) - cols = st.columns([2, 1]) - cols[0].markdown("## Venn diagrams from presets") - with cols[1].expander("Time view settings", expanded=True): - cols_1 = st.columns([1, 1]) - if_separate_plots = cols_1[0].checkbox("Separate in subplots", value=True) - if_sync_y_limits = cols_1[0].checkbox( - "Sync Y limits", value=True, disabled=not if_separate_plots + for i_venn, venn_preset in enumerate(VENN_PRESET): + # -- Venn diagrams -- + st.markdown(f"### ({i_venn+1}). 
{venn_preset['name']}") + fig, notes = generate_venn( + df_merged, + venn_preset ) - time_period = cols_1[1].selectbox( - "Bin size", - ["Daily", "Weekly", "Monthly", "Quarterly"], - index=1, + for note in notes: + st.markdown(note) + + cols = st.columns([1, 1]) + with cols[0]: + st.pyplot(fig, use_container_width=True) + + # -- Show and download df for this Venn -- + circle_columns = [c_s["column"] for c_s in venn_preset["circle_settings"]] + # Show histogram over time for the columns and patches in preset + df_this_preset = df_merged[circle_columns] + # Filter out rows that have at least one True in this Venn + df_this_preset = df_this_preset[df_this_preset.any(axis=1)] + + # Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"] + for patch_setting in venn_preset.get("patch_settings", []): + idx = _filter_df_by_patch_ids( + df_this_preset[circle_columns], + patch_setting["patch_ids"] ) + df_this_preset.loc[idx, str(patch_setting["patch_ids"])] = True - for i_venn, venn_preset in enumerate(VENN_PRESET): - # -- Venn diagrams -- - st.markdown(f"### ({i_venn+1}). 
{venn_preset['name']}") - fig, notes = generate_venn( - df_merged, - venn_preset - ) - for note in notes: - st.markdown(note) - - cols = st.columns([1, 1]) - with cols[0]: - st.pyplot(fig, use_container_width=True) - - # -- Show and download df for this Venn -- - circle_columns = [c_s["column"] for c_s in venn_preset["circle_settings"]] - # Show histogram over time for the columns and patches in preset - df_this_preset = df_merged[circle_columns] - # Filter out rows that have at least one True in this Venn - df_this_preset = df_this_preset[df_this_preset.any(axis=1)] - - # Create a new column to indicate sessions in patches specified by patch_ids like ["100", "101", "110", "111"] - for patch_setting in venn_preset.get("patch_settings", []): - idx = _filter_df_by_patch_ids( - df_this_preset[circle_columns], - patch_setting["patch_ids"] - ) - df_this_preset.loc[idx, str(patch_setting["patch_ids"])] = True + # Join in other extra columns + df_this_preset = df_this_preset.join( + df_merged[[col for col in df_merged.columns if col not in META_COLUMNS]], how="left" + ) - # Join in other extra columns - df_this_preset = df_this_preset.join( - df_merged[[col for col in df_merged.columns if col not in META_COLUMNS]], how="left" + with cols[0]: + download_df( + df_this_preset, + label="Download as CSV for this Venn diagram", + file_name=f"df_{venn_preset['name']}.csv", ) + with st.expander(f"Show dataframe, n = {len(df_this_preset)}"): + st.write(df_this_preset) - with cols[0]: - download_df( - df_this_preset, - label="Download as CSV for this Venn diagram", - file_name=f"df_{venn_preset['name']}.csv", - ) - with st.expander(f"Show dataframe, n = {len(df_this_preset)}"): - st.write(df_this_preset) - - with cols[1]: - # -- Show histogram over time -- - fig = plot_histogram_over_time( - df=df_this_preset.reset_index(), - venn_preset=venn_preset, - time_period=time_period, - if_sync_y_limits=if_sync_y_limits, - if_separate_plots=if_separate_plots, - ) - plotly_events( - fig, - 
click_event=False, - hover_event=False, - select_event=False, - override_height=fig.layout.height * 1.1, - override_width=fig.layout.width, - ) + with cols[1]: + # -- Show histogram over time -- + fig = plot_histogram_over_time( + df=df_this_preset.reset_index(), + venn_preset=venn_preset, + time_period=time_period, + if_sync_y_limits=if_sync_y_limits, + if_separate_plots=if_separate_plots, + ) + plotly_events( + fig, + click_event=False, + hover_event=False, + select_event=False, + override_height=fig.layout.height * 1.1, + override_width=fig.layout.width, + ) - st.markdown("---") + st.markdown("---") # --- User-defined Venn diagram --- # Multiselect for selecting queries up to three From b7e25175e1c3111cd2840354a52195bf6442f17f Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 10 Jan 2025 00:32:02 +0000 Subject: [PATCH 07/19] curriculum page: improve sorting --- code/Home.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/code/Home.py b/code/Home.py index 8db30d1..b7ca704 100644 --- a/code/Home.py +++ b/code/Home.py @@ -688,7 +688,9 @@ def add_tabs(): elif chosen_id == "tab_auto_train_curriculum": # Automatic training curriculums df_curriculums = st.session_state.curriculum_manager.df_curriculums().sort_values( - by=['curriculum_schema_version', 'curriculum_name', 'curriculum_version']).reset_index().drop(columns='index') + by=['curriculum_version', 'curriculum_schema_version', 'curriculum_name'], + ascending=[False, True, False], + ).reset_index().drop(columns='index') with placeholder: # Show curriculum manager dataframe st.markdown("#### Select auto training curriculums") From 30cbccd85b6245f04139a784e19a2a0fa35a58fe Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 10 Jan 2025 00:48:15 +0000 Subject: [PATCH 08/19] improve selectbox_wrapper_for_url_query --- code/util/url_query_helper.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/code/util/url_query_helper.py 
b/code/util/url_query_helper.py index e8fd43f..47f3396 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -95,17 +95,21 @@ def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs): **kwargs, ) -def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, **kwargs): +def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, default_override=True, **kwargs): + # If default_override, use default. Otherwise, session_state or query_params has higher priority + if not default_override: + default = ( + st.session_state[key] + if key in st.session_state and st.session_state[key] in options + else st.query_params[key] + if key in st.query_params and st.query_params[key] in options + else default + ) + return st_prefix.selectbox( label, options=options, - index=( - options.index(st.session_state[key]) - if key in st.session_state - else options.index(st.query_params[key]) - if key in st.query_params - else options.index(default) - ), + index=options.index(default), key=key, **kwargs, ) From 202f791fbd00ed55e30ebad118db99ebcc42ccf0 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 10 Jan 2025 00:49:02 +0000 Subject: [PATCH 09/19] set default_override=False by default --- code/util/url_query_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code/util/url_query_helper.py b/code/util/url_query_helper.py index 47f3396..06702e5 100644 --- a/code/util/url_query_helper.py +++ b/code/util/url_query_helper.py @@ -95,7 +95,7 @@ def checkbox_wrapper_for_url_query(st_prefix, label, key, default, **kwargs): **kwargs, ) -def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, default_override=True, **kwargs): +def selectbox_wrapper_for_url_query(st_prefix, label, options, key, default, default_override=False, **kwargs): # If default_override, use default. 
Otherwise, session_state or query_params has higher priority if not default_override: default = ( From 84397743146e46452b222cb959d18056ad11cfca Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 10 Jan 2025 00:55:01 +0000 Subject: [PATCH 10/19] improve curriculum tab --- code/Home.py | 64 ++++++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 35 deletions(-) diff --git a/code/Home.py b/code/Home.py index b7ca704..cae002c 100644 --- a/code/Home.py +++ b/code/Home.py @@ -38,6 +38,7 @@ add_footnote) from util.url_query_helper import (checkbox_wrapper_for_url_query, multiselect_wrapper_for_url_query, + selectbox_wrapper_for_url_query, number_input_wrapper_for_url_query, slider_wrapper_for_url_query, sync_session_state_to_URL, @@ -689,8 +690,9 @@ def add_tabs(): elif chosen_id == "tab_auto_train_curriculum": # Automatic training curriculums df_curriculums = st.session_state.curriculum_manager.df_curriculums().sort_values( by=['curriculum_version', 'curriculum_schema_version', 'curriculum_name'], - ascending=[False, True, False], - ).reset_index().drop(columns='index') + ascending=[False, True, False], + ).reset_index().drop(columns='index').query("curriculum_name != 'Dummy task'") + with placeholder: # Show curriculum manager dataframe st.markdown("#### Select auto training curriculums") @@ -698,51 +700,43 @@ def add_tabs(): # Curriculum drop down selector cols = st.columns([0.8, 0.5, 0.8, 4]) cols[3].markdown(f"(aind_auto_train lib version = {auto_train_version})") + options = list(df_curriculums['curriculum_name'].unique()) - selected_curriculum_name = cols[0].selectbox( - 'Curriculum name', + selected_curriculum_name = selectbox_wrapper_for_url_query( + st_prefix=cols[0], + label='Curriculum name', options=options, - index=options.index(st.session_state['auto_training_curriculum_name']) - if ('auto_training_curriculum_name' in st.session_state) and (st.session_state['auto_training_curriculum_name'] != '') else - 
options.index(st.query_params['auto_training_curriculum_name']) - if 'auto_training_curriculum_name' in st.query_params and st.query_params['auto_training_curriculum_name'] != '' - else 0, - key='auto_training_curriculum_name' - ) - + default=options[0], + default_override=True, + key='auto_training_curriculum_name', + ) + options = list(df_curriculums[ df_curriculums['curriculum_name'] == selected_curriculum_name ]['curriculum_version'].unique()) - if ('auto_training_curriculum_version' in st.session_state) and (st.session_state['auto_training_curriculum_version'] in options): - default = options.index(st.session_state['auto_training_curriculum_version']) - elif 'auto_training_curriculum_version' in st.query_params and st.query_params['auto_training_curriculum_version'] in options: - default = options.index(st.query_params['auto_training_curriculum_version']) - else: - default = 0 - selected_curriculum_version = cols[1].selectbox( - 'Curriculum version', - options=options, - index=default, - key='auto_training_curriculum_version' + selected_curriculum_version = selectbox_wrapper_for_url_query( + st_prefix=cols[1], + label='Curriculum version', + options=options, + default=options[0], + default_override=True, + key='auto_training_curriculum_version', ) options = list(df_curriculums[ (df_curriculums['curriculum_name'] == selected_curriculum_name) & (df_curriculums['curriculum_version'] == selected_curriculum_version) ]['curriculum_schema_version'].unique()) - if ('auto_training_curriculum_schema_version' in st.session_state) and (st.session_state['auto_training_curriculum_schema_version'] in options): - default = options.index(st.session_state['auto_training_curriculum_schema_version']) - elif 'auto_training_curriculum_schema_version' in st.query_params and st.query_params['auto_training_curriculum_schema_version'] in options: - default = options.index(st.query_params['auto_training_curriculum_schema_version']) - else: - default = 0 - 
selected_curriculum_schema_version = cols[2].selectbox( - 'Curriculum schema version', + + selected_curriculum_schema_version = selectbox_wrapper_for_url_query( + st_prefix=cols[2], + label='Curriculum schema version', options=options, - index=default, - key='auto_training_curriculum_schema_version' - ) - + default=options[0], + default_override=True, + key='auto_training_curriculum_schema_version', + ) + selected_curriculum = st.session_state.curriculum_manager.get_curriculum( curriculum_name=selected_curriculum_name, curriculum_schema_version=selected_curriculum_schema_version, From 74ff41cd41b20f99e04d7c269d380891cf9f92a5 Mon Sep 17 00:00:00 2001 From: "houhan@gmail.com" Date: Fri, 10 Jan 2025 01:54:24 +0000 Subject: [PATCH 11/19] minor refactor --- code/Home.py | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/code/Home.py b/code/Home.py index cae002c..86ca1e5 100644 --- a/code/Home.py +++ b/code/Home.py @@ -72,7 +72,7 @@ def _user_name_mapper(user_name): return name else: return user_name - + @st.cache_resource(ttl=24*3600) def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRenderer": @@ -132,22 +132,20 @@ def session_plot_settings(need_click=True): with st.form(key='session_plot_settings'): st.markdown('##### Show plots for individual sessions ') cols = st.columns([2, 6, 1]) - + session_plot_modes = [f'sessions selected from table or plot', f'all sessions filtered from sidebar'] - st.session_state.selected_draw_sessions = cols[0].selectbox(f'Which session(s) to draw?', - session_plot_modes, - index=session_plot_modes.index(st.session_state['session_plot_mode']) - if 'session_plot_mode' in st.session_state else - session_plot_modes.index(st.query_params['session_plot_mode']) - if 'session_plot_mode' in st.query_params - else 0, - key='session_plot_mode', - ) - + st.session_state.selected_draw_sessions = selectbox_wrapper_for_url_query( + cols[0], + label='Which session(s) to 
draw?', + options=session_plot_modes, + default=session_plot_modes[0], + key='session_plot_mode', + ) + n_session_to_draw = len(st.session_state.df_selected_from_plotly) \ if 'selected from table or plot' in st.session_state.selected_draw_sessions \ else len(st.session_state.df_session_filtered) - + _ = number_input_wrapper_for_url_query( st_prefix=cols[2], label='number of columns', @@ -156,7 +154,7 @@ def session_plot_settings(need_click=True): default=3, key='session_plot_number_cols', ) - + st.markdown( """