From 1fd149dd8f528771a3a4032eab3e2216eccf1e8f Mon Sep 17 00:00:00 2001 From: hh_itx_win10 Date: Mon, 8 Apr 2024 22:45:28 -0700 Subject: [PATCH] refactor: move s3 related stuff to util.aws_s3 --- code/Home.py | 225 ++----------------------------------- code/pages/3_Playground.py | 12 +- code/util/aws_s3.py | 133 ++++++++++++++++++++++ 3 files changed, 145 insertions(+), 225 deletions(-) create mode 100644 code/util/aws_s3.py diff --git a/code/Home.py b/code/Home.py index 766b1c4..02e67e4 100644 --- a/code/Home.py +++ b/code/Home.py @@ -15,39 +15,28 @@ # %% import pandas as pd import streamlit as st -from pathlib import Path -import glob -import matplotlib.pyplot as plt import numpy as np -from datetime import datetime -import s3fs import os -import plotly.express as px -import plotly -import plotly.graph_objects as go -import statsmodels.api as sm -import json - -from PIL import Image, ImageColor -import streamlit.components.v1 as components + import streamlit_nested_layout from streamlit_plotly_events import plotly_events from pygwalker.api.streamlit import StreamlitRenderer, init_streamlit_comm +import extra_streamlit_components as stx -# To suppress the warning that I set the default value of a widget and also set it in the session state -from streamlit.elements.utils import _shown_default_value_warning -_shown_default_value_warning = False - -from util.streamlit import (filter_dataframe, aggrid_interactive_table_session, +from util.streamlit import (aggrid_interactive_table_session, aggrid_interactive_table_curriculum, add_session_filter, data_selector, add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper, _plot_population_x_y) +from util.aws_s3 import ( + load_data, + draw_session_plots_quick_preview, + show_session_level_img_by_key_and_prefix, + show_debug_info, +) from util.url_query_helper import ( sync_widget_with_query, slider_wrapper_for_url_query, checkbox_wrapper_for_url_query ) -import extra_streamlit_components as stx - from aind_auto_train.curriculum_manager import CurriculumManager from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager @@ -105,14 +94,6 @@ } -data_sources = ['bonsai', 'bpod'] - -s3_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}/' for data in data_sources} -s3_processed_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}_processed/' for data in data_sources} - -fs = s3fs.S3FileSystem(anon=False) -st.session_state.use_s3 = True - try: st.set_page_config(layout="wide", page_title='Foraging behavior browser', @@ -125,43 +106,6 @@ except: pass -if 'selected_points' not in st.session_state: - st.session_state['selected_points'] = [] - -@st.cache_data(ttl=24*3600) -def load_data(tables=['sessions'], data_source = 'bonsai'): - df = {} - for table in tables: - file_name = s3_processed_nwb_folder[data_source] + f'df_{table}.pkl' - if st.session_state.use_s3: - with fs.open(file_name) as f: - df[table + '_bonsai'] = pd.read_pickle(f) - else: - df[table + '_bonsai'] = pd.read_pickle(file_name) - return df - -def _fetch_img(glob_patterns, crop=None): - # Fetch the img that first matches the patterns - for pattern in glob_patterns: - file = fs.glob(pattern) if st.session_state.use_s3 else glob.glob(pattern) - if len(file): break - - if not len(file): - return None, None - - try: - if st.session_state.use_s3: - with fs.open(file[0]) as f: - img = Image.open(f) - img = img.crop(crop) - else: - img = Image.open(file[0]) - img = img.crop(crop) - except: - st.write('File found on S3 but failed to load...') - return None, None - - return img, file[0] def _user_name_mapper(user_name): user_mapper = { # tuple of key words --> user name @@ -178,55 +122,7 @@ def _user_name_mapper(user_name): return name else: return user_name - -# @st.cache_data(ttl=24*3600, max_entries=20) -def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, data_source='bonsai', **kwargs): - try: - date_str = key["session_date"].strftime(r'%Y-%m-%d') - except: - date_str = key["session_date"].split("T")[0] - # Convert session_date to 2024-04-01 format - subject_session_date_str = f"{key['subject_id']}_{date_str}_{key['nwb_suffix']}".split('_0')[0] - glob_patterns = [s3_processed_nwb_folder[data_source] + f"{subject_session_date_str}/{subject_session_date_str}_{prefix}*"] - - img, f_name = _fetch_img(glob_patterns, crop) - - _f = st if column is None else column - - _f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png", - output_format='PNG', - caption=f_name.split('/')[-1] if caption and f_name else '', - use_column_width='always', - **kwargs) - - return img - -def show_mouse_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, **kwargs): - - fns = [f'/{key["h2o"]}_*{other_pattern}*' for other_pattern in other_patterns] - glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/' + fn for fn in fns] - - img, f_name = _fetch_img(glob_patterns, crop) - - if img is None: # Use "not_found" image - glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/not_found_*{other_pattern}**' for other_pattern in other_patterns] - img, f_name = _fetch_img(glob_patterns, crop) - - _f = st if column is None else column - - _f.stream(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png", - output_format='PNG', - #caption=f_name.split('/')[-1] if caption and f_name else '', - use_column_width='always', - **kwargs) - - return img - -# table_mapping = { -# 'sessions_bonsai': fetch_sessions, -# 'ephys_units': fetch_ephys_units, -# } @st.cache_resource(ttl=24*3600) def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRenderer": @@ -287,89 +183,6 @@ def draw_session_plots(df_to_draw_session): my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100)) -def draw_session_plots_quick_preview(df_to_draw_session): - - # Setting up layout for each session - layout_definition = [[1], # columns in the first row - [1, 1], - ] - draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)'] - - container_session_all_in_one = st.container() - - key = df_to_draw_session.to_dict(orient='records')[0] - - with container_session_all_in_one: - try: - date_str = key["session_date"].strftime('%Y-%m-%d') - except: - date_str = key["session_date"].split("T")[0] - - st.markdown(f'''
{key["h2o"]}, Session {int(key["session"])}, {date_str} ''' - f'''({key["user_name"]}@{key["data_source"]})''', - unsafe_allow_html=True) - - rows = [] - for row, column_setting in enumerate(layout_definition): - rows.append(st.columns(column_setting)) - - for draw_type in draw_types_quick_preview: - if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level - prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type] - this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] - show_session_level_img_by_key_and_prefix( - key, - column=this_col, - prefix=prefix, - data_source=key["hardware"], - **setting, - ) - - -def draw_mice_plots(df_to_draw_mice): - - # Setting up layout for each session - layout_definition = [[1], # columns in the first row - ] - - # cols_option = st.columns([3, 0.5, 1]) - container_session_all_in_one = st.container() - - with container_session_all_in_one: - # with st.expander("Expand to see all-in-one plot for selected unit", expanded=True): - - if len(df_to_draw_mice): - st.write(f'Loading selected {len(df_to_draw_mice)} mice...') - my_bar = st.columns((1, 7))[0].progress(0) - - major_cols = st.columns([1] * st.session_state.num_cols_mice) - - for i, key in enumerate(df_to_draw_mice.to_dict(orient='records')): - this_major_col = major_cols[i % st.session_state.num_cols_mice] - - # setting up layout for each session - rows = [] - with this_major_col: - st.markdown(f'''

{key["h2o"]}''', - unsafe_allow_html=True) - if len(st.session_state.selected_draw_types_mice) > 1: # more than one types, use the pre-defined layout - for row, column_setting in enumerate(layout_definition): - rows.append(this_major_col.columns(column_setting)) - else: # else, put it in the whole column - rows = this_major_col.columns([1]) - st.markdown("---") - - for draw_type in st.session_state.draw_type_mapper_mouse_level: - if draw_type not in st.session_state.selected_draw_types_mice: continue - prefix, position, setting = st.session_state.draw_type_mapper_mouse_level[draw_type] - this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types_mice) > 1 else rows[0] - show_mouse_level_img_by_key_and_prefix(key, - column=this_col, - prefix=prefix, - **setting) - - my_bar.progress(int((i + 1) / len(df_to_draw_mice) * 100)) - def session_plot_settings(need_click=True): st.markdown('##### Show plots for individual sessions ') @@ -957,24 +770,8 @@ def app(): if chosen_id != "tab_auto_train_curriculum": for _ in range(10): st.write('\n') st.markdown('---\n##### Debug zone') - with st.expander('CO processing NWB errors', expanded=False): - error_file = s3_processed_nwb_folder['bonsai'] + 'error_files.json' - if fs.exists(error_file): - with fs.open(error_file) as file: - st.json(json.load(file)) - else: - st.write('No NWB error files') - - with st.expander('CO Pipeline log', expanded=False): - with fs.open(s3_processed_nwb_folder['bonsai'] + 'pipeline.log') as file: - log_content = file.read().decode('utf-8') - log_content = log_content.replace('\\n', '\n') - st.text(log_content) - - with st.expander('NWB convertion and upload log', expanded=False): - with fs.open(s3_nwb_folder['bonsai'] + 'bonsai_pipeline.log') as file: - log_content = file.read().decode('utf-8') - st.text(log_content) + show_debug_info() + # Update back to URL diff --git a/code/pages/3_Playground.py b/code/pages/3_Playground.py index e25c9e2..f495937 100644 --- a/code/pages/3_Playground.py +++ b/code/pages/3_Playground.py @@ -9,23 +9,13 @@ from streamlit_plotly_events import plotly_events from util.streamlit import add_session_filter, data_selector +from util.aws_s3 import load_data ss = st.session_state fs = s3fs.S3FileSystem(anon=False) cache_folder = 'aind-behavior-data/foraging_nwb_bonsai_processed/' -@st.cache_data(ttl=24*3600) -def load_data(tables=['sessions']): - df = {} - for table in tables: - file_name = cache_folder + f'df_{table}.pkl' - if st.session_state.use_s3: - with fs.open(file_name) as f: - df[table + '_bonsai'] = pd.read_pickle(f) - else: - df[table + '_bonsai'] = pd.read_pickle(file_name) - return df def app(): diff --git a/code/util/aws_s3.py b/code/util/aws_s3.py new file mode 100644 index 0000000..2cd5997 --- /dev/null +++ b/code/util/aws_s3.py @@ -0,0 +1,133 @@ +from PIL import Image +import glob +import json + +import s3fs +import pandas as pd +import streamlit as st + +# -------------------------------------- +data_sources = ['bonsai', 'bpod'] + +s3_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}/' for data in data_sources} +s3_processed_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}_processed/' for data in data_sources} +# -------------------------------------- + +fs = s3fs.S3FileSystem(anon=False) + + +if 'selected_points' not in st.session_state: + st.session_state['selected_points'] = [] + +@st.cache_data(ttl=24*3600) +def load_data(tables=['sessions'], data_source = 'bonsai'): + df = {} + for table in tables: + file_name = s3_processed_nwb_folder[data_source] + f'df_{table}.pkl' + with fs.open(file_name) as f: + df[table + '_bonsai'] = pd.read_pickle(f) + return df + + +def draw_session_plots_quick_preview(df_to_draw_session): + + # Setting up layout for each session + layout_definition = [[1], # columns in the first row + [1, 1], + ] + draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)'] + + container_session_all_in_one = st.container() + + key = df_to_draw_session.to_dict(orient='records')[0] + + with container_session_all_in_one: + try: + date_str = key["session_date"].strftime('%Y-%m-%d') + except: + date_str = key["session_date"].split("T")[0] + + st.markdown(f'''

{key["h2o"]}, Session {int(key["session"])}, {date_str} ''' + f'''({key["user_name"]}@{key["data_source"]})''', + unsafe_allow_html=True) + + rows = [] + for row, column_setting in enumerate(layout_definition): + rows.append(st.columns(column_setting)) + + for draw_type in draw_types_quick_preview: + if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level + prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type] + this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0] + show_session_level_img_by_key_and_prefix( + key, + column=this_col, + prefix=prefix, + data_source=key["hardware"], + **setting, + ) + + +# @st.cache_data(ttl=24*3600, max_entries=20) +def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, data_source='bonsai', **kwargs): + try: + date_str = key["session_date"].strftime(r'%Y-%m-%d') + except: + date_str = key["session_date"].split("T")[0] + + # Convert session_date to 2024-04-01 format + subject_session_date_str = f"{key['subject_id']}_{date_str}_{key['nwb_suffix']}".split('_0')[0] + glob_patterns = [s3_processed_nwb_folder[data_source] + f"{subject_session_date_str}/{subject_session_date_str}_{prefix}*"] + + img, f_name = _fetch_img(glob_patterns, crop) + + _f = st if column is None else column + + _f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png", + output_format='PNG', + caption=f_name.split('/')[-1] if caption and f_name else '', + use_column_width='always', + **kwargs) + + return img + + +def _fetch_img(glob_patterns, crop=None): + # Fetch the img that first matches the patterns + for pattern in glob_patterns: + file = fs.glob(pattern) + if len(file): break + + if not len(file): + return None, None + + try: + with fs.open(file[0]) as f: + img = Image.open(f) + img = img.crop(crop) + except: + st.write('File found on S3 but failed to load...') + return None, None + + return img, file[0] + + +def show_debug_info(): + with st.expander('CO processing NWB errors', expanded=False): + error_file = s3_processed_nwb_folder['bonsai'] + 'error_files.json' + if fs.exists(error_file): + with fs.open(error_file) as file: + st.json(json.load(file)) + else: + st.write('No NWB error files') + + with st.expander('CO Pipeline log', expanded=False): + with fs.open(s3_processed_nwb_folder['bonsai'] + 'pipeline.log') as file: + log_content = file.read().decode('utf-8') + log_content = log_content.replace('\\n', '\n') + st.text(log_content) + + with st.expander('NWB convertion and upload log', expanded=False): + with fs.open(s3_nwb_folder['bonsai'] + 'bonsai_pipeline.log') as file: + log_content = file.read().decode('utf-8') + st.text(log_content) \ No newline at end of file