diff --git a/code/Home.py b/code/Home.py
index 766b1c4..02e67e4 100644
--- a/code/Home.py
+++ b/code/Home.py
@@ -15,39 +15,28 @@
# %%
import pandas as pd
import streamlit as st
-from pathlib import Path
-import glob
-import matplotlib.pyplot as plt
import numpy as np
-from datetime import datetime
-import s3fs
import os
-import plotly.express as px
-import plotly
-import plotly.graph_objects as go
-import statsmodels.api as sm
-import json
-
-from PIL import Image, ImageColor
-import streamlit.components.v1 as components
+
import streamlit_nested_layout
from streamlit_plotly_events import plotly_events
from pygwalker.api.streamlit import StreamlitRenderer, init_streamlit_comm
+import extra_streamlit_components as stx
-# To suppress the warning that I set the default value of a widget and also set it in the session state
-from streamlit.elements.utils import _shown_default_value_warning
-_shown_default_value_warning = False
-
-from util.streamlit import (filter_dataframe, aggrid_interactive_table_session,
+from util.streamlit import (aggrid_interactive_table_session,
aggrid_interactive_table_curriculum, add_session_filter, data_selector,
add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper,
_plot_population_x_y)
+from util.aws_s3 import (
+ load_data,
+ draw_session_plots_quick_preview,
+ show_session_level_img_by_key_and_prefix,
+ show_debug_info,
+)
from util.url_query_helper import (
sync_widget_with_query, slider_wrapper_for_url_query, checkbox_wrapper_for_url_query
)
-import extra_streamlit_components as stx
-
from aind_auto_train.curriculum_manager import CurriculumManager
from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager
@@ -105,14 +94,6 @@
}
-data_sources = ['bonsai', 'bpod']
-
-s3_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}/' for data in data_sources}
-s3_processed_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}_processed/' for data in data_sources}
-
-fs = s3fs.S3FileSystem(anon=False)
-st.session_state.use_s3 = True
-
try:
st.set_page_config(layout="wide",
page_title='Foraging behavior browser',
@@ -125,43 +106,6 @@
except:
pass
-if 'selected_points' not in st.session_state:
- st.session_state['selected_points'] = []
-
-@st.cache_data(ttl=24*3600)
-def load_data(tables=['sessions'], data_source = 'bonsai'):
- df = {}
- for table in tables:
- file_name = s3_processed_nwb_folder[data_source] + f'df_{table}.pkl'
- if st.session_state.use_s3:
- with fs.open(file_name) as f:
- df[table + '_bonsai'] = pd.read_pickle(f)
- else:
- df[table + '_bonsai'] = pd.read_pickle(file_name)
- return df
-
-def _fetch_img(glob_patterns, crop=None):
- # Fetch the img that first matches the patterns
- for pattern in glob_patterns:
- file = fs.glob(pattern) if st.session_state.use_s3 else glob.glob(pattern)
- if len(file): break
-
- if not len(file):
- return None, None
-
- try:
- if st.session_state.use_s3:
- with fs.open(file[0]) as f:
- img = Image.open(f)
- img = img.crop(crop)
- else:
- img = Image.open(file[0])
- img = img.crop(crop)
- except:
- st.write('File found on S3 but failed to load...')
- return None, None
-
- return img, file[0]
def _user_name_mapper(user_name):
user_mapper = { # tuple of key words --> user name
@@ -178,55 +122,7 @@ def _user_name_mapper(user_name):
return name
else:
return user_name
-
-# @st.cache_data(ttl=24*3600, max_entries=20)
-def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, data_source='bonsai', **kwargs):
- try:
- date_str = key["session_date"].strftime(r'%Y-%m-%d')
- except:
- date_str = key["session_date"].split("T")[0]
- # Convert session_date to 2024-04-01 format
- subject_session_date_str = f"{key['subject_id']}_{date_str}_{key['nwb_suffix']}".split('_0')[0]
- glob_patterns = [s3_processed_nwb_folder[data_source] + f"{subject_session_date_str}/{subject_session_date_str}_{prefix}*"]
-
- img, f_name = _fetch_img(glob_patterns, crop)
-
- _f = st if column is None else column
-
- _f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
- output_format='PNG',
- caption=f_name.split('/')[-1] if caption and f_name else '',
- use_column_width='always',
- **kwargs)
-
- return img
-
-def show_mouse_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, **kwargs):
-
- fns = [f'/{key["h2o"]}_*{other_pattern}*' for other_pattern in other_patterns]
- glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/' + fn for fn in fns]
-
- img, f_name = _fetch_img(glob_patterns, crop)
-
- if img is None: # Use "not_found" image
- glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/not_found_*{other_pattern}**' for other_pattern in other_patterns]
- img, f_name = _fetch_img(glob_patterns, crop)
-
- _f = st if column is None else column
-
- _f.stream(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
- output_format='PNG',
- #caption=f_name.split('/')[-1] if caption and f_name else '',
- use_column_width='always',
- **kwargs)
-
- return img
-
-# table_mapping = {
-# 'sessions_bonsai': fetch_sessions,
-# 'ephys_units': fetch_ephys_units,
-# }
@st.cache_resource(ttl=24*3600)
def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRenderer":
@@ -287,89 +183,6 @@ def draw_session_plots(df_to_draw_session):
my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100))
-def draw_session_plots_quick_preview(df_to_draw_session):
-
- # Setting up layout for each session
- layout_definition = [[1], # columns in the first row
- [1, 1],
- ]
- draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)']
-
- container_session_all_in_one = st.container()
-
- key = df_to_draw_session.to_dict(orient='records')[0]
-
- with container_session_all_in_one:
- try:
- date_str = key["session_date"].strftime('%Y-%m-%d')
- except:
- date_str = key["session_date"].split("T")[0]
-
- st.markdown(f'''
{key["h2o"]}, Session {int(key["session"])}, {date_str} '''
- f'''({key["user_name"]}@{key["data_source"]})''',
- unsafe_allow_html=True)
-
- rows = []
- for row, column_setting in enumerate(layout_definition):
- rows.append(st.columns(column_setting))
-
- for draw_type in draw_types_quick_preview:
- if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level
- prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type]
- this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0]
- show_session_level_img_by_key_and_prefix(
- key,
- column=this_col,
- prefix=prefix,
- data_source=key["hardware"],
- **setting,
- )
-
-
-def draw_mice_plots(df_to_draw_mice):
-
- # Setting up layout for each session
- layout_definition = [[1], # columns in the first row
- ]
-
- # cols_option = st.columns([3, 0.5, 1])
- container_session_all_in_one = st.container()
-
- with container_session_all_in_one:
- # with st.expander("Expand to see all-in-one plot for selected unit", expanded=True):
-
- if len(df_to_draw_mice):
- st.write(f'Loading selected {len(df_to_draw_mice)} mice...')
- my_bar = st.columns((1, 7))[0].progress(0)
-
- major_cols = st.columns([1] * st.session_state.num_cols_mice)
-
- for i, key in enumerate(df_to_draw_mice.to_dict(orient='records')):
- this_major_col = major_cols[i % st.session_state.num_cols_mice]
-
- # setting up layout for each session
- rows = []
- with this_major_col:
- st.markdown(f'''{key["h2o"]}''',
- unsafe_allow_html=True)
- if len(st.session_state.selected_draw_types_mice) > 1: # more than one types, use the pre-defined layout
- for row, column_setting in enumerate(layout_definition):
- rows.append(this_major_col.columns(column_setting))
- else: # else, put it in the whole column
- rows = this_major_col.columns([1])
- st.markdown("---")
-
- for draw_type in st.session_state.draw_type_mapper_mouse_level:
- if draw_type not in st.session_state.selected_draw_types_mice: continue
- prefix, position, setting = st.session_state.draw_type_mapper_mouse_level[draw_type]
- this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types_mice) > 1 else rows[0]
- show_mouse_level_img_by_key_and_prefix(key,
- column=this_col,
- prefix=prefix,
- **setting)
-
- my_bar.progress(int((i + 1) / len(df_to_draw_mice) * 100))
-
def session_plot_settings(need_click=True):
st.markdown('##### Show plots for individual sessions ')
@@ -957,24 +770,8 @@ def app():
if chosen_id != "tab_auto_train_curriculum":
for _ in range(10): st.write('\n')
st.markdown('---\n##### Debug zone')
- with st.expander('CO processing NWB errors', expanded=False):
- error_file = s3_processed_nwb_folder['bonsai'] + 'error_files.json'
- if fs.exists(error_file):
- with fs.open(error_file) as file:
- st.json(json.load(file))
- else:
- st.write('No NWB error files')
-
- with st.expander('CO Pipeline log', expanded=False):
- with fs.open(s3_processed_nwb_folder['bonsai'] + 'pipeline.log') as file:
- log_content = file.read().decode('utf-8')
- log_content = log_content.replace('\\n', '\n')
- st.text(log_content)
-
- with st.expander('NWB convertion and upload log', expanded=False):
- with fs.open(s3_nwb_folder['bonsai'] + 'bonsai_pipeline.log') as file:
- log_content = file.read().decode('utf-8')
- st.text(log_content)
+ show_debug_info()
+
# Update back to URL
diff --git a/code/pages/3_Playground.py b/code/pages/3_Playground.py
index e25c9e2..f495937 100644
--- a/code/pages/3_Playground.py
+++ b/code/pages/3_Playground.py
@@ -9,23 +9,13 @@
from streamlit_plotly_events import plotly_events
from util.streamlit import add_session_filter, data_selector
+from util.aws_s3 import load_data
ss = st.session_state
fs = s3fs.S3FileSystem(anon=False)
cache_folder = 'aind-behavior-data/foraging_nwb_bonsai_processed/'
-@st.cache_data(ttl=24*3600)
-def load_data(tables=['sessions']):
- df = {}
- for table in tables:
- file_name = cache_folder + f'df_{table}.pkl'
- if st.session_state.use_s3:
- with fs.open(file_name) as f:
- df[table + '_bonsai'] = pd.read_pickle(f)
- else:
- df[table + '_bonsai'] = pd.read_pickle(file_name)
- return df
def app():
diff --git a/code/util/aws_s3.py b/code/util/aws_s3.py
new file mode 100644
index 0000000..2cd5997
--- /dev/null
+++ b/code/util/aws_s3.py
@@ -0,0 +1,133 @@
+from PIL import Image
+import glob
+import json
+
+import s3fs
+import pandas as pd
+import streamlit as st
+
+# --------------------------------------
+# Data-source names; each one is a key of the two S3 folder maps below.
+data_sources = ['bonsai', 'bpod']
+
+# S3 prefix per data source. In this module the raw-NWB folder is only used
+# to read 'bonsai_pipeline.log'; the "_processed" folder is where the df_*.pkl
+# tables, the session-level images, 'error_files.json' and 'pipeline.log'
+# are read from.
+s3_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}/' for data in data_sources}
+s3_processed_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}_processed/' for data in data_sources}
+# --------------------------------------
+
+# Module-wide S3 filesystem handle shared by every function below; it is
+# created at import time with anonymous access disabled (anon=False).
+fs = s3fs.S3FileSystem(anon=False)
+
+
+# Import-time side effect: make sure the 'selected_points' key exists in the
+# Streamlit session state, starting as an empty list.
+if 'selected_points' not in st.session_state:
+ st.session_state['selected_points'] = []
+
+@st.cache_data(ttl=24*3600)
+def load_data(tables=['sessions'], data_source='bonsai'):
+    """Read pickled DataFrames ``df_<table>.pkl`` from the processed S3 folder.
+
+    The result is cached by Streamlit for 24 hours (``ttl=24*3600`` seconds).
+
+    Args:
+        tables: Names of the tables to load; each maps to the file
+            ``df_<table>.pkl`` in the processed folder of ``data_source``.
+        data_source: Key into ``s3_processed_nwb_folder``.
+
+    Returns:
+        dict mapping ``'<table>_bonsai'`` to the unpickled object. The
+        ``_bonsai`` suffix is used whatever ``data_source`` is.
+    """
+    loaded = {}
+    for table in tables:
+        s3_path = f'{s3_processed_nwb_folder[data_source]}df_{table}.pkl'
+        # NOTE: unpickling can execute arbitrary code, so the bucket content
+        # has to be trusted.
+        with fs.open(s3_path) as pkl_file:
+            frame = pd.read_pickle(pkl_file)
+        loaded[table + '_bonsai'] = frame
+    return loaded
+
+
+def draw_session_plots_quick_preview(df_to_draw_session):
+    """Render a quick preview (title line plus selected plots) of one session.
+
+    Only the first record of ``df_to_draw_session`` is drawn. The plots shown
+    are the quick-preview draw types that are also present in
+    ``st.session_state.selected_draw_types``.
+
+    Args:
+        df_to_draw_session: DataFrame whose first record provides the keys
+            ``h2o``, ``session``, ``session_date``, ``user_name``,
+            ``data_source`` and ``hardware``, plus the keys read by
+            ``show_session_level_img_by_key_and_prefix``.
+
+    Raises:
+        IndexError: If ``df_to_draw_session`` has no rows.
+    """
+    # Relative column widths of each row: one full-width row, then a row
+    # split into two equal columns.
+    layout_definition = [[1],
+                         [1, 1],
+                         ]
+    draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)']
+
+    container_session_all_in_one = st.container()
+
+    key = df_to_draw_session.to_dict(orient='records')[0]
+
+    with container_session_all_in_one:
+        try:
+            date_str = key["session_date"].strftime('%Y-%m-%d')
+        except AttributeError:
+            # No strftime: presumably session_date is already a string such
+            # as "2024-04-01T..."; keep the part before "T".
+            date_str = key["session_date"].split("T")[0]
+
+        st.markdown(f'''{key["h2o"]}, Session {int(key["session"])}, {date_str} '''
+                    f'''({key["user_name"]}@{key["data_source"]})''',
+                    unsafe_allow_html=True)
+
+        # One list of Streamlit columns per layout row.
+        rows = [st.columns(column_setting) for column_setting in layout_definition]
+
+        use_layout = len(st.session_state.selected_draw_types) > 1
+        for draw_type in draw_types_quick_preview:
+            # Skip preview plots the user has not selected.
+            if draw_type not in st.session_state.selected_draw_types:
+                continue
+            prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type]
+            # With a single selected plot, use the one full-width column.
+            # rows[0] is the *list* returned by st.columns, so it must be
+            # indexed once more to get a column that has an .image method.
+            this_col = rows[position[0]][position[1]] if use_layout else rows[0][0]
+            show_session_level_img_by_key_and_prefix(
+                key,
+                column=this_col,
+                prefix=prefix,
+                data_source=key["hardware"],
+                **setting,
+            )
+
+
+# @st.cache_data(ttl=24*3600, max_entries=20)
+def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, data_source='bonsai', **kwargs):
+    """Display the first S3 image matching a session and a file-name prefix.
+
+    The image is searched under
+    ``s3_processed_nwb_folder[data_source]/<session>/<session>_<prefix>*``,
+    where ``<session>`` is built from ``subject_id``, the session date and
+    ``nwb_suffix``. A placeholder icon is shown when no image could be loaded.
+
+    Args:
+        key: Mapping with ``session_date`` (an object with ``strftime`` or a
+            string containing the date before a "T"), ``subject_id`` and
+            ``nwb_suffix``.
+        prefix: File-name prefix of the wanted figure.
+        column: Streamlit container to draw into; ``None`` draws on the page.
+        other_patterns: Accepted for signature compatibility; not used here.
+        crop: Crop box forwarded to ``_fetch_img``.
+        caption: If true, caption the image with its file name.
+        data_source: Key into ``s3_processed_nwb_folder``.
+        **kwargs: Forwarded to the ``image`` call.
+
+    Returns:
+        The loaded image, or ``None`` when the placeholder was shown.
+    """
+    try:
+        date_str = key["session_date"].strftime(r'%Y-%m-%d')
+    except AttributeError:
+        # No strftime: treat session_date as a string and keep the date part.
+        date_str = key["session_date"].split("T")[0]
+
+    # Everything from the first '_0' onwards is dropped.
+    # NOTE(review): presumably meant to strip a '_0' nwb_suffix; it also cuts
+    # any suffix that merely starts with '0' - confirm this is intended.
+    session_id = f"{key['subject_id']}_{date_str}_{key['nwb_suffix']}".split('_0')[0]
+    glob_patterns = [s3_processed_nwb_folder[data_source] + f"{session_id}/{session_id}_{prefix}*"]
+
+    img, f_name = _fetch_img(glob_patterns, crop)
+
+    target = st if column is None else column
+    target.image(
+        img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
+        output_format='PNG',
+        caption=f_name.split('/')[-1] if caption and f_name else '',
+        use_column_width='always',
+        **kwargs)
+
+    return img
+
+
+def _fetch_img(glob_patterns, crop=None):
+    """Load the first S3 image matched by any of the given glob patterns.
+
+    Args:
+        glob_patterns: Iterable of S3 glob patterns, tried in order; the
+            first pattern with at least one match wins.
+        crop: Box passed unchanged to PIL's ``Image.crop``.
+
+    Returns:
+        ``(img, path)`` with the cropped PIL image and the matched S3 path,
+        or ``(None, None)`` when nothing matches or the file fails to load.
+    """
+    # Start empty so that an empty glob_patterns returns (None, None) instead
+    # of failing on an unbound name.
+    matches = []
+    for pattern in glob_patterns:
+        matches = fs.glob(pattern)
+        if matches:
+            break
+
+    if not matches:
+        return None, None
+
+    try:
+        with fs.open(matches[0]) as f:
+            # Crop inside the `with` so the pixel data is read while the S3
+            # file is still open.
+            img = Image.open(f).crop(crop)
+    except Exception:
+        # Best effort: report and fall back to "no image". `Exception`
+        # rather than a bare except, so interpreter-level signals such as
+        # KeyboardInterrupt are not swallowed.
+        st.write('File found on S3 but failed to load...')
+        return None, None
+
+    return img, matches[0]
+
+
+def show_debug_info():
+    """Render the debug expanders: NWB processing errors and two pipeline logs.
+
+    All files are read from the 'bonsai' S3 folders. A missing file is
+    reported inside its expander instead of raising, so one absent log does
+    not break the rest of the page.
+    """
+    with st.expander('CO processing NWB errors', expanded=False):
+        error_file = s3_processed_nwb_folder['bonsai'] + 'error_files.json'
+        if fs.exists(error_file):
+            with fs.open(error_file) as file:
+                st.json(json.load(file))
+        else:
+            st.write('No NWB error files')
+
+    with st.expander('CO Pipeline log', expanded=False):
+        try:
+            with fs.open(s3_processed_nwb_folder['bonsai'] + 'pipeline.log') as file:
+                log_content = file.read().decode('utf-8')
+        except FileNotFoundError:
+            st.write('No pipeline log found')
+        else:
+            # Turn literal backslash-n sequences into real line breaks.
+            log_content = log_content.replace('\\n', '\n')
+            st.text(log_content)
+
+    with st.expander('NWB convertion and upload log', expanded=False):
+        try:
+            with fs.open(s3_nwb_folder['bonsai'] + 'bonsai_pipeline.log') as file:
+                log_content = file.read().decode('utf-8')
+        except FileNotFoundError:
+            st.write('No upload log found')
+        else:
+            st.text(log_content)
\ No newline at end of file