Skip to content

Commit

Permalink
Merge branch 'han_refactor'
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhou committed Apr 9, 2024
2 parents cb2f4b3 + 1fd149d commit 1230283
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 225 deletions.
225 changes: 11 additions & 214 deletions code/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,28 @@
# %%
import pandas as pd
import streamlit as st
from pathlib import Path
import glob
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
import s3fs
import os
import plotly.express as px
import plotly
import plotly.graph_objects as go
import statsmodels.api as sm
import json

from PIL import Image, ImageColor
import streamlit.components.v1 as components

import streamlit_nested_layout
from streamlit_plotly_events import plotly_events
from pygwalker.api.streamlit import StreamlitRenderer, init_streamlit_comm
import extra_streamlit_components as stx

# To suppress the warning that I set the default value of a widget and also set it in the session state
from streamlit.elements.utils import _shown_default_value_warning
_shown_default_value_warning = False

from util.streamlit import (filter_dataframe, aggrid_interactive_table_session,
from util.streamlit import (aggrid_interactive_table_session,
aggrid_interactive_table_curriculum, add_session_filter, data_selector,
add_xy_selector, add_xy_setting, add_auto_train_manager, add_dot_property_mapper,
_plot_population_x_y)
from util.aws_s3 import (
load_data,
draw_session_plots_quick_preview,
show_session_level_img_by_key_and_prefix,
show_debug_info,
)
from util.url_query_helper import (
sync_widget_with_query, slider_wrapper_for_url_query, checkbox_wrapper_for_url_query
)

import extra_streamlit_components as stx

from aind_auto_train.curriculum_manager import CurriculumManager
from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager

Expand Down Expand Up @@ -105,14 +94,6 @@
}


data_sources = ['bonsai', 'bpod']

s3_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}/' for data in data_sources}
s3_processed_nwb_folder = {data: f'aind-behavior-data/foraging_nwb_{data}_processed/' for data in data_sources}

fs = s3fs.S3FileSystem(anon=False)
st.session_state.use_s3 = True

try:
st.set_page_config(layout="wide",
page_title='Foraging behavior browser',
Expand All @@ -125,43 +106,6 @@
except:
pass

if 'selected_points' not in st.session_state:
st.session_state['selected_points'] = []

@st.cache_data(ttl=24*3600)
def load_data(tables=['sessions'], data_source = 'bonsai'):
df = {}
for table in tables:
file_name = s3_processed_nwb_folder[data_source] + f'df_{table}.pkl'
if st.session_state.use_s3:
with fs.open(file_name) as f:
df[table + '_bonsai'] = pd.read_pickle(f)
else:
df[table + '_bonsai'] = pd.read_pickle(file_name)
return df

def _fetch_img(glob_patterns, crop=None):
# Fetch the img that first matches the patterns
for pattern in glob_patterns:
file = fs.glob(pattern) if st.session_state.use_s3 else glob.glob(pattern)
if len(file): break

if not len(file):
return None, None

try:
if st.session_state.use_s3:
with fs.open(file[0]) as f:
img = Image.open(f)
img = img.crop(crop)
else:
img = Image.open(file[0])
img = img.crop(crop)
except:
st.write('File found on S3 but failed to load...')
return None, None

return img, file[0]

def _user_name_mapper(user_name):
user_mapper = { # tuple of key words --> user name
Expand All @@ -178,55 +122,7 @@ def _user_name_mapper(user_name):
return name
else:
return user_name

# @st.cache_data(ttl=24*3600, max_entries=20)
def show_session_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, data_source='bonsai', **kwargs):
try:
date_str = key["session_date"].strftime(r'%Y-%m-%d')
except:
date_str = key["session_date"].split("T")[0]

# Convert session_date to 2024-04-01 format
subject_session_date_str = f"{key['subject_id']}_{date_str}_{key['nwb_suffix']}".split('_0')[0]
glob_patterns = [s3_processed_nwb_folder[data_source] + f"{subject_session_date_str}/{subject_session_date_str}_{prefix}*"]

img, f_name = _fetch_img(glob_patterns, crop)

_f = st if column is None else column

_f.image(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
output_format='PNG',
caption=f_name.split('/')[-1] if caption and f_name else '',
use_column_width='always',
**kwargs)

return img

def show_mouse_level_img_by_key_and_prefix(key, prefix, column=None, other_patterns=[''], crop=None, caption=True, **kwargs):

fns = [f'/{key["h2o"]}_*{other_pattern}*' for other_pattern in other_patterns]
glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/' + fn for fn in fns]

img, f_name = _fetch_img(glob_patterns, crop)

if img is None: # Use "not_found" image
glob_patterns = [cache_mouse_level_fig_folder + f'{prefix}/not_found_*{other_pattern}**' for other_pattern in other_patterns]
img, f_name = _fetch_img(glob_patterns, crop)

_f = st if column is None else column

_f.stream(img if img is not None else "https://cdn-icons-png.flaticon.com/512/3585/3585596.png",
output_format='PNG',
#caption=f_name.split('/')[-1] if caption and f_name else '',
use_column_width='always',
**kwargs)

return img

# table_mapping = {
# 'sessions_bonsai': fetch_sessions,
# 'ephys_units': fetch_ephys_units,
# }

@st.cache_resource(ttl=24*3600)
def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRenderer":
Expand Down Expand Up @@ -287,89 +183,6 @@ def draw_session_plots(df_to_draw_session):

my_bar.progress(int((i + 1) / len(df_to_draw_session) * 100))

def draw_session_plots_quick_preview(df_to_draw_session):

# Setting up layout for each session
layout_definition = [[1], # columns in the first row
[1, 1],
]
draw_types_quick_preview = ['1. Choice history', '2. Logistic regression (Su2022)']

container_session_all_in_one = st.container()

key = df_to_draw_session.to_dict(orient='records')[0]

with container_session_all_in_one:
try:
date_str = key["session_date"].strftime('%Y-%m-%d')
except:
date_str = key["session_date"].split("T")[0]

st.markdown(f'''<h5 style='text-align: center; color: orange;'>{key["h2o"]}, Session {int(key["session"])}, {date_str} '''
f'''({key["user_name"]}@{key["data_source"]})''',
unsafe_allow_html=True)

rows = []
for row, column_setting in enumerate(layout_definition):
rows.append(st.columns(column_setting))

for draw_type in draw_types_quick_preview:
if draw_type not in st.session_state.selected_draw_types: continue # To keep the draw order defined by st.session_state.draw_type_mapper_session_level
prefix, position, setting = st.session_state.draw_type_mapper_session_level[draw_type]
this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types) > 1 else rows[0]
show_session_level_img_by_key_and_prefix(
key,
column=this_col,
prefix=prefix,
data_source=key["hardware"],
**setting,
)


def draw_mice_plots(df_to_draw_mice):

# Setting up layout for each session
layout_definition = [[1], # columns in the first row
]

# cols_option = st.columns([3, 0.5, 1])
container_session_all_in_one = st.container()

with container_session_all_in_one:
# with st.expander("Expand to see all-in-one plot for selected unit", expanded=True):

if len(df_to_draw_mice):
st.write(f'Loading selected {len(df_to_draw_mice)} mice...')
my_bar = st.columns((1, 7))[0].progress(0)

major_cols = st.columns([1] * st.session_state.num_cols_mice)

for i, key in enumerate(df_to_draw_mice.to_dict(orient='records')):
this_major_col = major_cols[i % st.session_state.num_cols_mice]

# setting up layout for each session
rows = []
with this_major_col:
st.markdown(f'''<h3 style='text-align: center; color: orange;'>{key["h2o"]}''',
unsafe_allow_html=True)
if len(st.session_state.selected_draw_types_mice) > 1: # more than one types, use the pre-defined layout
for row, column_setting in enumerate(layout_definition):
rows.append(this_major_col.columns(column_setting))
else: # else, put it in the whole column
rows = this_major_col.columns([1])
st.markdown("---")

for draw_type in st.session_state.draw_type_mapper_mouse_level:
if draw_type not in st.session_state.selected_draw_types_mice: continue
prefix, position, setting = st.session_state.draw_type_mapper_mouse_level[draw_type]
this_col = rows[position[0]][position[1]] if len(st.session_state.selected_draw_types_mice) > 1 else rows[0]
show_mouse_level_img_by_key_and_prefix(key,
column=this_col,
prefix=prefix,
**setting)

my_bar.progress(int((i + 1) / len(df_to_draw_mice) * 100))


def session_plot_settings(need_click=True):
st.markdown('##### Show plots for individual sessions ')
Expand Down Expand Up @@ -957,24 +770,8 @@ def app():
if chosen_id != "tab_auto_train_curriculum":
for _ in range(10): st.write('\n')
st.markdown('---\n##### Debug zone')
with st.expander('CO processing NWB errors', expanded=False):
error_file = s3_processed_nwb_folder['bonsai'] + 'error_files.json'
if fs.exists(error_file):
with fs.open(error_file) as file:
st.json(json.load(file))
else:
st.write('No NWB error files')

with st.expander('CO Pipeline log', expanded=False):
with fs.open(s3_processed_nwb_folder['bonsai'] + 'pipeline.log') as file:
log_content = file.read().decode('utf-8')
log_content = log_content.replace('\\n', '\n')
st.text(log_content)

with st.expander('NWB convertion and upload log', expanded=False):
with fs.open(s3_nwb_folder['bonsai'] + 'bonsai_pipeline.log') as file:
log_content = file.read().decode('utf-8')
st.text(log_content)
show_debug_info()



# Update back to URL
Expand Down
12 changes: 1 addition & 11 deletions code/pages/3_Playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,13 @@
from streamlit_plotly_events import plotly_events

from util.streamlit import add_session_filter, data_selector
from util.aws_s3 import load_data

ss = st.session_state

fs = s3fs.S3FileSystem(anon=False)
cache_folder = 'aind-behavior-data/foraging_nwb_bonsai_processed/'

@st.cache_data(ttl=24*3600)
def load_data(tables=['sessions']):
df = {}
for table in tables:
file_name = cache_folder + f'df_{table}.pkl'
if st.session_state.use_s3:
with fs.open(file_name) as f:
df[table + '_bonsai'] = pd.read_pickle(f)
else:
df[table + '_bonsai'] = pd.read_pickle(file_name)
return df

def app():

Expand Down
Loading

0 comments on commit 1230283

Please sign in to comment.